diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -13,6 +13,8 @@
 #ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
 #define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
 
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
 
 // Forward declarations.
@@ -32,8 +34,15 @@
 /// operands that have been legalized by the conversion framework. This can only
 /// be done if the branch operation implements the BranchOpInterface. Only
 /// needed for partial conversions.
-void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns,
-                                                    TypeConverter &converter);
+///
+/// If, for some branch ops, only a subset of the op's operands need to be
+/// converted/legalized, such filtering behavior can be specified in
+/// branchOpsOperandConversionFilter, which maps an op to the subset of its
+/// operand indices that need to be converted.
+void populateBranchOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter,
+    const DenseMap<Operation *, DenseSet<int>>
+        *branchOpsOperandConversionFilter = nullptr);
 
 /// Return true if op is a BranchOpInterface op whose operands are all legal
 /// according to converter.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -498,8 +498,13 @@
   /// Convert the types of block arguments within the given region except for
   /// the entry region. This replaces each non-entry block with a new block
   /// containing the updated signature.
-  LogicalResult convertNonEntryRegionTypes(Region *region,
-                                           TypeConverter &converter);
+  ///
+  /// If special conversion behavior is needed for the non-entry blocks (for
+  /// example, to convert only a subset of a block's arguments), such
+  /// behavior can be specified in blockConversions.
+  LogicalResult convertNonEntryRegionTypes(
+      Region *region, TypeConverter &converter,
+      SmallVectorImpl<TypeConverter::SignatureConversion> *blockConversions);
 
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -15,6 +15,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include <deque>
 #include <iterator>
 #include <memory>
@@ -39,13 +41,21 @@
 /// Defines the criteria a TensorType must follow in order to be considered
 /// "detensorable".
 ///
-/// NOTE: For now, only 0-D are supported.
+/// NOTE: For now, only 0-D tensors are supported.
 ///
 /// Returns true if tensorType can be detensored.
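+/// For example, `tensor<i32>` (a 0-D tensor) can be detensored to a plain
+/// `i32`, while `tensor<10xi32>` cannot.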
 bool canBeDetensored(TensorType tensorType) {
   return tensorType.hasRank() && tensorType.getRank() == 0;
 }
 
+bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
+  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
+  return genericOp && llvm::all_of(genericOp.getShapedOperandTypes(),
+                                   [&](ShapedType shapedType) {
+                                     return !typeConverter.isLegal(shapedType);
+                                   });
+}
+
 /// A conversion pattern for detensoring `linalg.generic` ops.
 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
 public:
@@ -81,17 +91,36 @@
 /// A conversion pattern for detensoring internal (non-entry) blocks within a
 /// function.
 struct FunctionNonEntryBlockConversion : public ConversionPattern {
-  FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
-                                  MLIRContext *ctx, TypeConverter &converter)
-      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
+  FunctionNonEntryBlockConversion(
+      StringRef functionLikeOpName, MLIRContext *ctx, TypeConverter &converter,
+      DenseMap<Block *, DenseSet<int>> blockArgumentDetensoring)
+      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx),
+        blockArgumentDetensoring(std::move(blockArgumentDetensoring)) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.startRootUpdate(op);
+    Region &region = mlir::impl::getFunctionBody(op);
+    SmallVector<TypeConverter::SignatureConversion, 2> conversions;
+
+    for (Block &block : llvm::drop_begin(region, 1)) {
+      conversions.emplace_back(block.getNumArguments());
+      TypeConverter::SignatureConversion &back = conversions.back();
+      DenseSet<int> blockArgumentDetensoringFilter =
+          blockArgumentDetensoring.lookup(&block);
+
+      for (unsigned int idx = 0; idx < block.getNumArguments(); ++idx) {
+        if (blockArgumentDetensoringFilter.count(idx))
+          back.addInputs(idx, {getTypeConverter()->convertType(
+                                  block.getArgumentTypes()[idx])});
+        else
+          back.addInputs(idx, {block.getArgumentTypes()[idx]});
+      }
+    }
 
-    if (failed(rewriter.convertNonEntryRegionTypes(
-            &mlir::impl::getFunctionBody(op), *typeConverter))) {
+    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
+                                                   &conversions))) {
       rewriter.cancelRootUpdate(op);
       return failure();
     }
@@ -99,6 +128,9 @@
     rewriter.finalizeRootUpdate(op);
     return success();
   }
+
+private:
+  const DenseMap<Block *, DenseSet<int>> blockArgumentDetensoring;
 };
 
 class DetensorizeTypeConverter : public TypeConverter {
@@ -160,46 +192,291 @@
 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+  LinalgDetensorize() = default;
+  LinalgDetensorize(const LinalgDetensorize &pass) {}
+
+  class CostModel {
+  public:
+    virtual ~CostModel() = default;
+
+    /// A cost model algorithm computes the following outputs:
+    ///
+    /// - detensorableLinalgOps: the list of linalg ops that should be
+    /// detensored.
+    ///
+    /// - detensorableBranchOps: a map whose keys are branch ops and whose
+    /// values are operand indices for such keys. The set of operand indices
+    /// corresponding to a branch op specifies which subset of the branch's
+    /// operands should be detensored (i.e. converted by typeConverter).
+    ///
+    /// - blockArgumentDetensoring: since the operands and results of
+    /// detensored linalg ops can cross the BB boundary (e.g.
+    /// a linalg op's input can come from a BB argument and a linalg op's
+    /// output can be passed to successor BBs), we need to maintain the
+    /// subset of arguments that should be detensored (i.e. converted by
+    /// typeConverter) for each affected BB.
+    ///
+    /// Example:
+    ///
+    /// For the following snippet:
+    /// ...
+    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
+    ///   %7 = linalg.init_tensor [] : tensor<i32>
+    ///   %8 = linalg.generic #attrs
+    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
+    ///     outs(%7 : tensor<i32>) {
+    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
+    ///       %11 = addi %arg0, %arg1 : i32
+    ///       linalg.yield %11 : i32
+    ///   } -> tensor<i32>
+    ///   %10 = "some.op"(%9)
+    ///   br ^bb2(%8 : tensor<i32>)
+    /// ...
+    ///
+    /// if the cost model decides that the linalg.generic op should be
+    /// detensored, then:
+    /// - detensorableLinalgOps should be = {linalg.generic{add}}.
+    /// - detensorableBranchOps should be = {br -> {0}}.
+    /// - blockArgumentDetensoring should be = {bb1 -> {0}, bb2 -> {0}}.
+    virtual void
+    compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+            DenseSet<Operation *> &detensorableLinalgOps,
+            DenseMap<Operation *, DenseSet<int>> &detensorableBranchOps,
+            DenseMap<Block *, DenseSet<int>> &blockArgumentDetensoring) = 0;
+  };
+
+  class PureControlFlowDetectionModel : public CostModel {
+    void compute(
+        FuncOp func, DetensorizeTypeConverter typeConverter,
+        DenseSet<Operation *> &detensorableLinalgOps,
+        DenseMap<Operation *, DenseSet<int>> &detensorableBranchOps,
+        DenseMap<Block *, DenseSet<int>> &blockArgumentDetensoring) override {
+      // TODO: The following code is implemented with loops in mind. We might
+      // need to add support for if conditions later on.
+
+      DenseSet<Operation *> workList;
+      // 1. Find which detensorable ops are involved in control flow (i.e.
+      // they produce tensors that are then used in a cond_br's condition).
+      func.walk([&](CondBranchOp condBr) {
+        auto *chainOp = condBr.condition().getDefiningOp();
+
+        while (chainOp && !isa<GenericOp>(chainOp)) {
+          if (chainOp->getNumOperands() != 1)
+            break;
+
+          chainOp = chainOp->getOperand(0).getDefiningOp();
+        }
+
+        if (!shouldBeDetensored(chainOp, typeConverter))
+          return;
+
+        workList.insert(chainOp);
+      });
+
+      // 2. Discover other detensorable ops by walking the def-use chain
+      // backwards, starting from the detensorable ops currently on the
+      // workList.
+      while (!workList.empty()) {
+        GenericOp detensorableOp = cast<GenericOp>(*workList.begin());
+        detensorableLinalgOps.insert(detensorableOp);
+        workList.erase(workList.begin());
+
+        // Discover where the detensorableOp's operands come from.
+        for (Value operand : detensorableOp.inputs())
+          if (!discoverDetensorableComponent(
+                  operand, typeConverter, workList, detensorableLinalgOps,
+                  detensorableBranchOps, blockArgumentDetensoring)) {
+            // TODO: For now, we assume there is one opportunity for
+            // detensoring in a function. This can be extended to support
+            // multiple separate components in a single function.
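+            // Reset any partially collected state so that the function is
+            // left unchanged rather than partially detensored.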
+            detensorableLinalgOps.clear();
+            detensorableBranchOps.clear();
+            blockArgumentDetensoring.clear();
+            return;
+          }
+      }
+    }
+
+  private:
+    bool discoverDetensorableComponent(
+        Value operand, TypeConverter typeConverter,
+        DenseSet<Operation *> &workList,
+        const DenseSet<Operation *> &detensorableLinalgOps,
+        DenseMap<Operation *, DenseSet<int>> &detensorableBranchOps,
+        DenseMap<Block *, DenseSet<int>> &blockArgumentDetensoring) {
+      auto *definingOp = operand.getDefiningOp();
+
+      if (definingOp) {
+        if (comesFromElements(definingOp))
+          return true;
+
+        if (!shouldBeDetensored(definingOp, typeConverter))
+          return false;
+
+        if (!workList.count(definingOp) &&
+            !detensorableLinalgOps.count(definingOp))
+          workList.insert(definingOp);
+
+        return true;
+      }
+
+      BlockArgument blockArgument = operand.cast<BlockArgument>();
+      Block *ownerBlock = blockArgument.getOwner();
+
+      if (&*ownerBlock->getParent()->begin() == ownerBlock)
+        return true;
+
+      blockArgumentDetensoring[ownerBlock].insert(blockArgument.getArgNumber());
+
+      for (PredecessorIterator pred = ownerBlock->pred_begin();
+           pred != ownerBlock->pred_end(); ++pred) {
+        BranchOpInterface terminator =
+            dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+        auto ownerBlockOperands =
+            terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+        // TODO: Add a test where the same operand is passed more than once to
+        // the same block.
+        if (!ownerBlockOperands || ownerBlockOperands->empty())
+          continue;
+
+        auto predOperand =
+            ownerBlockOperands.getValue()[blockArgument.getArgNumber()];
+
+        for (int idx = ownerBlockOperands->getBeginOperandIndex(),
+                 eidx = idx + ownerBlockOperands->size();
+             idx < eidx; ++idx)
+          if (terminator->getOperand(idx) == predOperand)
+            detensorableBranchOps[terminator].insert(idx);
+
+        if (!discoverDetensorableComponent(
+                predOperand, typeConverter, workList, detensorableLinalgOps,
+                detensorableBranchOps, blockArgumentDetensoring))
+          return false;
+      }
+
+      return true;
+    }
+
+    bool comesFromElements(Operation *op) {
+      while (op && !isa<tensor::FromElementsOp>(op)) {
+        if (op->getNumOperands() > 1)
+          return false;
+
+        op = op->getOperand(0).getDefiningOp();
+      }
+
+      return op != nullptr;
+    }
+  };
+
+  /// Detensorize everything that can be detensored.
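+  ///
+  /// For example, for the loop in the detensorize_while.mlir test added
+  /// below, this model detensors both rank-0 linalg.generic ops as well as
+  /// every branch operand and block argument they flow through (see the
+  /// DET-ALL check lines in that test).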
+  class AggressiveDetensoringModel : public CostModel {
+    void compute(
+        FuncOp func, DetensorizeTypeConverter typeConverter,
+        DenseSet<Operation *> &detensorableLinalgOps,
+        DenseMap<Operation *, DenseSet<int>> &detensorableBranchOps,
+        DenseMap<Block *, DenseSet<int>> &blockArgumentDetensoring) override {
+      func.walk([&](GenericOp genericOp) {
+        if (shouldBeDetensored(genericOp, typeConverter))
+          detensorableLinalgOps.insert(genericOp);
+      });
+
+      func.walk([&](BranchOpInterface brOp) {
+        DenseSet<int> brOpOperandDetensoring;
+
+        for (int p = 0, e = brOp->getBlock()->getNumSuccessors(); p < e; ++p) {
+          auto successorOperands = brOp.getSuccessorOperands(p);
+          Block *successor = brOp->getSuccessor(p);
+
+          if (!successorOperands.hasValue())
+            continue;
+
+          for (int idx = successorOperands->getBeginOperandIndex(),
+                   eidx = idx + successorOperands->size();
+               idx < eidx; ++idx) {
+            brOpOperandDetensoring.insert(idx);
+            blockArgumentDetensoring[successor].insert(
+                idx - successorOperands->getBeginOperandIndex());
+          }
+        }
+
+        detensorableBranchOps.try_emplace(brOp,
+                                          std::move(brOpOperandDetensoring));
+      });
+    }
+  };
+
   void runOnFunction() override {
-    auto *context = &getContext();
+    MLIRContext *context = &getContext();
     DetensorizeTypeConverter typeConverter;
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
+    DenseSet<Operation *> detensorableLinalgOps;
+    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
+    DenseMap<Block *, DenseSet<int>> blockArgumentDetensoring;
 
-    target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
-      // If any of the operands or results cannot be detensored (i.e. they are
-      // all legal according the DetensorizeTypeConverter), the op is considered
-      // legal and won't be detensored.
-      return llvm::any_of(op.getShapedOperandTypes(),
-                          [&](ShapedType shapedType) {
-                            return typeConverter.isLegal(shapedType);
-                          });
-    });
+    std::unique_ptr<CostModel> costModel;
+
+    if (aggressiveMode.getValue())
+      costModel = std::make_unique<AggressiveDetensoringModel>();
+    else
+      costModel = std::make_unique<PureControlFlowDetectionModel>();
+
+    costModel->compute(getFunction(), typeConverter, detensorableLinalgOps,
+                       detensorableBranchOps, blockArgumentDetensoring);
+
+    target.addDynamicallyLegalOp<GenericOp>(
+        [&](GenericOp op) { return !detensorableLinalgOps.count(op); });
 
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      // A function is legal if all of its non-entry blocks are legal. We don't
-      // legalize the entry block (i.e. the function's signature) since
-      // detensoring can't happen along external calling convention boundaries,
-      // which we conservatively approximate as all function signatures.
+      // A function is legal if all of its non-entry blocks are legal. We
+      // don't legalize the entry block (i.e. the function's signature) since
+      // detensoring can't happen along external calling convention
+      // boundaries, which we conservatively approximate as all function
+      // signatures.
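+      // That is, a block becomes illegal as soon as any of its arguments
+      // selected for detensoring still has an illegal (tensor) type.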
       return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
-        return typeConverter.isLegal(block.getArgumentTypes());
+        if (blockArgumentDetensoring.count(&block) &&
+            llvm::any_of(blockArgumentDetensoring[&block], [&](int idx) {
+              return !typeConverter.isLegal(block.getArgumentTypes()[idx]);
+            })) {
+          return false;
+        }
+        return true;
       });
     });
 
     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
-             isLegalForBranchOpInterfaceTypeConversionPattern(op,
-                                                              typeConverter) ||
-             isLegalForReturnOpTypeConversionPattern(
-                 op, typeConverter, /*returnOpAlwaysLegal*/ true);
+      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
+          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
+                                                  /*returnOpAlwaysLegal*/ true))
+        return true;
+
+      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+        if (!detensorableBranchOps.count(branchOp))
+          return true;
+
+        for (auto operandIdx : detensorableBranchOps[branchOp])
+          if (!typeConverter.isLegal(
+                  branchOp->getOperand(operandIdx).getType()))
+            return false;
+
+        return true;
+      }
+
+      return false;
     });
 
-    patterns.add<DetensorizeGenericOp>(typeConverter, context);
-    patterns.add<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
-                                                  context, typeConverter);
-    // Since non-entry block arguments get detensorized, we also need to update
-    // the control flow inside the function to reflect the correct types.
-    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
+    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+    patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
+                                                     context, typeConverter,
+                                                     blockArgumentDetensoring);
+    // Since non-entry block arguments get detensorized, we also need to
+    // update the control flow inside the function to reflect the correct
+    // types.
+    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
+                                                   &detensorableBranchOps);
 
     if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
@@ -210,6 +487,11 @@
                                     std::move(canonPatterns))))
       signalPassFailure();
   }
+
+  Option<bool> aggressiveMode{
+      *this, "aggressive-mode",
+      llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
+                     "with branch operands and basic-block arguments.")};
 };
 
 } // namespace
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -52,27 +52,44 @@
   using OpInterfaceConversionPattern<
       BranchOpInterface>::OpInterfaceConversionPattern;
 
+  BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
+                                  MLIRContext *ctx,
+                                  const DenseMap<Operation *, DenseSet<int>>
+                                      *branchOpsOperandConversionFilter)
+      : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
+        branchOpsOperandConversionFilter(branchOpsOperandConversionFilter) {}
+
   LogicalResult
   matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    DenseSet<int> opOperandFilter;
+
+    if (branchOpsOperandConversionFilter)
+      opOperandFilter = branchOpsOperandConversionFilter->lookup(op);
+
     // For a branch operation, only some operands go to the target blocks, so
     // only rewrite those.
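+    // For example, the condition of a `cond_br` is not a successor operand
+    // and is therefore never rewritten here.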
     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
          succIdx < succEnd; ++succIdx) {
       auto successorOperands = op.getSuccessorOperands(succIdx);
-      if (!successorOperands)
+      if (!successorOperands || successorOperands->empty())
         continue;
+
       for (int idx = successorOperands->getBeginOperandIndex(),
                eidx = idx + successorOperands->size();
            idx < eidx; ++idx) {
-        newOperands[idx] = operands[idx];
+        if (!branchOpsOperandConversionFilter || opOperandFilter.count(idx))
+          newOperands[idx] = operands[idx];
       }
     }
     rewriter.updateRootInPlace(
         op, [newOperands, op]() { op->setOperands(newOperands); });
     return success();
   }
+
+private:
+  const DenseMap<Operation *, DenseSet<int>> *branchOpsOperandConversionFilter;
 };
 } // end anonymous namespace
@@ -98,9 +115,11 @@
 } // end anonymous namespace
 
 void mlir::populateBranchOpInterfaceTypeConversionPattern(
-    RewritePatternSet &patterns, TypeConverter &typeConverter) {
-  patterns.add<BranchOpInterfaceTypeConversion>(typeConverter,
-                                                patterns.getContext());
+    RewritePatternSet &patterns, TypeConverter &typeConverter,
+    const DenseMap<Operation *, DenseSet<int>>
+        *branchOpsOperandConversionFilter) {
+  patterns.insert<BranchOpInterfaceTypeConversion>(
+      typeConverter, patterns.getContext(), branchOpsOperandConversionFilter);
 }
 
 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -495,8 +495,17 @@
   // to pack the new values. For 1->1 mappings, if there is no materialization
   // provided, use the argument directly instead.
   auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-  Value newArg = converter.materializeArgumentConversion(
-      rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+  Value newArg;
+
+  // If this is a 1->1 mapping and the types of the original and replacement
+  // arguments match (i.e. the mapping is an identity), use the replacement
+  // argument directly without materializing a conversion.
+  if (replArgs.size() == 1 && replArgs[0].getType() == origArg.getType())
+    newArg = replArgs[0];
+  else
+    newArg = converter.materializeArgumentConversion(
+        rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+
   if (!newArg) {
     assert(replArgs.size() == 1 &&
            "couldn't materialize the result of 1->N conversion");
@@ -754,8 +763,9 @@
                     TypeConverter::SignatureConversion *entryConversion);
 
   /// Convert the types of non-entry block arguments within the given region.
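+  /// If non-null, blockConversions is expected to contain one signature
+  /// conversion per non-entry block of the region, in the order in which the
+  /// blocks appear in the region.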
-  LogicalResult convertNonEntryRegionTypes(Region *region,
-                                           TypeConverter &converter);
+  LogicalResult convertNonEntryRegionTypes(
+      Region *region, TypeConverter &converter,
+      SmallVectorImpl<TypeConverter::SignatureConversion> *blockConversions);
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -1164,7 +1174,7 @@
   if (region->empty())
     return nullptr;
 
-  if (failed(convertNonEntryRegionTypes(region, converter)))
+  if (failed(convertNonEntryRegionTypes(region, converter, nullptr)))
     return failure();
 
   FailureOr<Block *> newEntry =
@@ -1173,14 +1183,18 @@
 }
 
 LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
-    Region *region, TypeConverter &converter) {
+    Region *region, TypeConverter &converter,
+    SmallVectorImpl<TypeConverter::SignatureConversion> *blockConversions) {
   argConverter.setConverter(region, &converter);
   if (region->empty())
     return success();
 
   // Convert the arguments of each block within the region.
+  int blockIdx = 0;
   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
-    if (failed(convertBlockSignature(&block, converter)))
+    if (failed(convertBlockSignature(
+            &block, converter,
+            blockConversions ? &(*blockConversions)[blockIdx++] : nullptr)))
       return failure();
   return success();
 }
@@ -1351,8 +1365,9 @@
 }
 
 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
-    Region *region, TypeConverter &converter) {
-  return impl->convertNonEntryRegionTypes(region, converter);
+    Region *region, TypeConverter &converter,
+    SmallVectorImpl<TypeConverter::SignatureConversion> *blockConversions) {
+  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+  br ^bb1(%farg0 : tensor<i32>)
+
+^bb1(%0: tensor<i32>): // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i1>
+  %2 = linalg.generic #attrs
+    ins(%0, %farg1 : tensor<i32>, tensor<i32>)
+    outs(%1 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %3 = tensor.extract %2[] : tensor<i1>
+  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
+
+^bb2(%4: tensor<i32>): // pred: ^bb1
+  %5 = linalg.init_tensor [] : tensor<i32>
+  %6 = linalg.generic #attrs
+    ins(%4, %4 : tensor<i32>, tensor<i32>)
+    outs(%5 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
+      %8 = addi %arg0, %arg1 : i32
+      linalg.yield %8 : i32
+  } -> tensor<i32>
+  br ^bb1(%6 : tensor<i32>)
+
+^bb3(%7: tensor<i32>): // pred: ^bb1
+  return %7 : tensor<i32>
+}
+
+// Test aggressively detensoring all detensorable ops.
+//
+// DET-ALL-LABEL: func @main
+// DET-ALL-SAME: (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// DET-ALL: tensor.extract {{.*}}
+// DET-ALL: br ^[[bb1:.*]](%{{.*}} : i32)
+// DET-ALL: ^[[bb1]](%{{.*}}: i32)
+// DET-ALL: cmpi slt, {{.*}}
+// DET-ALL: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: addi {{.*}}
+// DET-ALL: br ^[[bb1]](%{{.*}} : i32)
+// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: tensor.from_elements {{.*}}
+// DET-ALL: linalg.tensor_reshape {{.*}}
+// DET-ALL: return %{{.*}} : tensor<i32>
+
+// Test detensoring only the ops involved in control flow.
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME: (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// DET-CF: tensor.extract {{.*}}
+// DET-CF: br ^[[bb1:.*]](%{{.*}} : i32)
+// DET-CF: ^[[bb1]](%{{.*}}: i32)
+// DET-CF-DAG: tensor.from_elements {{.*}}
+// DET-CF-DAG: linalg.tensor_reshape {{.*}}
+// DET-CF-DAG: cmpi slt, {{.*}}
+// DET-CF: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : tensor<i32>)
+// DET-CF: ^[[bb2]](%{{.*}}: i32)
+// DET-CF: addi {{.*}}
+// DET-CF: br ^[[bb1]](%{{.*}} : i32)
+// DET-CF: ^[[bb3]](%{{.*}}: tensor<i32>)
+// DET-CF: return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
@@ -0,0 +1,111 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+#map0 = affine_map<() -> ()>
+#map1 = affine_map<(i) -> ()>
+#map2 = affine_map<(i) -> (i)>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+#sum_reduction_attrs = {
+  indexing_maps = [#map2, #map1],
+  iterator_types = ["reduction"]
+}
+
+#broadcast_attrs = {
+  indexing_maps = [#map1, #map2],
+  iterator_types = ["parallel"]
+}
+
+func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+  br ^bb1(%farg0 : tensor<10xi32>)
+
+^bb1(%0: tensor<10xi32>): // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i32>
+  %2 = linalg.generic #sum_reduction_attrs
+    ins(%0: tensor<10xi32>)
+    outs(%1: tensor<i32>) {
+    ^bb(%a: i32, %x: i32):
+      %b = addi %x, %a : i32
+      linalg.yield %b : i32
+  } -> tensor<i32>
+
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %farg1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>): // pred: ^bb1
+  %7 = linalg.init_tensor [10] : tensor<10xi32>
+  %9 = linalg.generic #broadcast_attrs
+    ins(%6: tensor<i32>)
+    outs(%7: tensor<10xi32>) {
+    ^bb(%a: i32, %b: i32) :
+      linalg.yield %a : i32
+  } -> tensor<10xi32>
+
+  br ^bb1(%9 : tensor<10xi32>)
+
+^bb3(%10: tensor<i32>): // pred: ^bb1
+  return %10 : tensor<i32>
+}
+
+// Test aggressively detensoring all detensorable ops.
+//
+// DET-ALL-LABEL: func @main
+// DET-ALL-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-ALL: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-ALL: ^[[bb1]](%{{.*}}: tensor<10xi32>)
+// DET-ALL: linalg.init_tensor [] : tensor<i32>
+// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32): // no predecessors
+// DET-ALL: %{{.*}} = addi %{{.*}}, %{{.*}}
+// DET-ALL: linalg.yield %{{.*}} : i32
+// DET-ALL: } -> tensor<i32>
+// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32
+// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL: linalg.init_tensor [10] : tensor<10xi32>
+// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32):
+// DET-ALL: linalg.yield %{{.*}} : i32
+// DET-ALL: } -> tensor<10xi32>
+// DET-ALL: br ^[[bb1]](%{{.*}} : tensor<10xi32>)
+// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL: return %{{.*}} : tensor<i32>
+// DET-ALL: }
+
+// Try to detensor pure control flow. However, this fails since the potential
+// detensorable component contains some ops that cannot be detensored.
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-CF: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-CF: ^bb1(%{{.*}}: tensor<10xi32>)
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
+// DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-CF: br ^bb1(%{{.*}} : tensor<10xi32>)
+// DET-CF: ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF: return %{{.*}} : tensor<i32>
+// DET-CF: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main() -> () attributes {} {
+  %c0 = constant 0 : i32
+  %0 = tensor.from_elements %c0 : tensor<1xi32>
+  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %c10 = constant 10 : i32
+  %1 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3
+
+^bb2(%6: tensor<i32>): // pred: ^bb1
+  %7 = linalg.init_tensor [] : tensor<i32>
+  %8 = linalg.generic #attrs
+    ins(%6, %6 : tensor<i32>, tensor<i32>)
+    outs(%7 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
+      %9 = addi %arg0, %arg1 : i32
+      linalg.yield %9 : i32
+  } -> tensor<i32>
+  br ^bb1(%8 : tensor<i32>)
+
+^bb3: // pred: ^bb1
+  return
+}
+
+// CHECK-LABEL: func @main
+// CHECK-DAG: %[[c0:.*]] = constant 0 : i32
+// CHECK-DAG: %[[c10:.*]] = constant 10 : i32
+// CHECK: br ^[[bb1:.*]](%[[c0]] : i32)
+// CHECK: ^[[bb1]](%{{.*}}: i32)
+// CHECK: cmpi slt, %{{.*}}, %[[c10]]
+// CHECK: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
+// CHECK: ^[[bb2]](%{{.*}}: i32)
+// CHECK: addi %{{.*}}, %{{.*}}
+// CHECK: br ^[[bb1]](%{{.*}} : i32)
+// CHECK: ^[[bb3]]:
+// CHECK: return
diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
--- a/mlir/test/Dialect/Linalg/detensorized_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize=aggressive-mode | FileCheck %s
 
 #map = affine_map<() -> ()>
diff --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir
deleted file mode 100644
--- a/mlir/test/Dialect/Linalg/detensorized_while.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s
-
-#map0 = affine_map<() -> ()>
-
-#attrs = {
-  indexing_maps = [#map0, #map0, #map0],
-  iterator_types = []
-}
-
-func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
-  br ^bb1(%farg0 : tensor<i32>)
-
-^bb1(%0: tensor<i32>): // 2 preds: ^bb0, ^bb2
-  %1 = linalg.init_tensor [] : tensor<i1>
-  %2 = linalg.generic #attrs
-    ins(%0, %farg1 : tensor<i32>, tensor<i32>)
-    outs(%1 : tensor<i1>) {
-    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
-      %8 = cmpi slt, %arg0, %arg1 : i32
-      linalg.yield %8 : i1
-  } -> tensor<i1>
-  %3 = tensor.extract %2[] : tensor<i1>
-  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
-
-^bb2(%4: tensor<i32>): // pred: ^bb1
-  %5 = linalg.init_tensor [] : tensor<i32>
-  %6 = linalg.generic #attrs
-    ins(%4, %4 : tensor<i32>, tensor<i32>)
-    outs(%5 : tensor<i32>) {
-    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
-      %8 = addi %arg0, %arg1 : i32
-      linalg.yield %8 : i32
-  } -> tensor<i32>
-  br ^bb1(%6 : tensor<i32>)
-
-^bb3(%7: tensor<i32>): // pred: ^bb1
-  return %7 : tensor<i32>
-}
-
-// CHECK-LABEL: func @main
-// CHECK-SAME: (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
-// CHECK: tensor.extract {{.*}}
-// CHECK: br ^[[bb1:.*]](%{{.*}} : i32)
-// CHECK: ^[[bb1]](%{{.*}}: i32)
-// CHECK: cmpi slt, {{.*}}
-// CHECK: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK: ^[[bb2]](%{{.*}}: i32)
-// CHECK: addi {{.*}}
-// CHECK: br ^[[bb1]](%{{.*}} : i32)
-// CHECK: ^[[bb3]](%{{.*}}: i32)
-// CHECK: tensor.from_elements {{.*}}
-// CHECK: linalg.tensor_reshape {{.*}}
-// CHECK: return %{{.*}} : tensor<i32>