diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h @@ -13,9 +13,13 @@ #ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_ #define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_ +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" + namespace mlir { // Forward declarations. +class BranchOpInterface; class ConversionTarget; class MLIRContext; class Operation; @@ -32,8 +36,15 @@ /// operands that have been legalized by the conversion framework. This can only /// be done if the branch operation implements the BranchOpInterface. Only /// needed for partial conversions. -void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, - TypeConverter &converter); +/// +/// If for some branch ops, we need to convert/legalize only a sub-set of the +/// op's operands, such filtering behavior can be specified in +/// shouldConvertBranchOperand. This callback should return true if branchOp's +/// operand at index idx should be converted. +void populateBranchOpInterfaceTypeConversionPattern( + RewritePatternSet &patterns, TypeConverter &converter, + function_ref + shouldConvertBranchOperand = nullptr); /// Return true if op is a BranchOpInterface op whose operands are all legal /// according to converter. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -498,8 +498,13 @@ /// Convert the types of block arguments within the given region except for /// the entry region. This replaces each non-entry block with a new block /// containing the updated signature. 
- LogicalResult convertNonEntryRegionTypes(Region *region, - TypeConverter &converter); + /// + /// If special conversion behavior is needed for the non-entry blocks (for + /// example, we need to convert only a subset of a BB arguments), such + /// behavior can be specified in blockConversions. + LogicalResult convertNonEntryRegionTypes( + Region *region, TypeConverter &converter, + ArrayRef blockConversions); /// Replace all the uses of the block argument `from` with value `to`. void replaceUsesOfBlockArgument(BlockArgument from, Value to); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -39,13 +39,21 @@ /// Defines the criteria a TensorType must follow in order to be considered /// "detensorable". /// -/// NOTE: For now, only 0-D are supported. +/// NOTE: For now, only 0-D tensors are supported. /// /// Returns true if tensorType can be detensored. bool canBeDetensored(TensorType tensorType) { return tensorType.hasRank() && tensorType.getRank() == 0; } +bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { + GenericOp genericOp = dyn_cast_or_null(op); + return genericOp && llvm::all_of(genericOp.getShapedOperandTypes(), + [&](ShapedType shapedType) { + return !typeConverter.isLegal(shapedType); + }); +} + /// A conversion patttern for detensoring `linalg.generic` ops. class DetensorizeGenericOp : public OpConversionPattern { public: @@ -82,16 +90,35 @@ /// function. 
struct FunctionNonEntryBlockConversion : public ConversionPattern { FunctionNonEntryBlockConversion(StringRef functionLikeOpName, - MLIRContext *ctx, TypeConverter &converter) - : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} + MLIRContext *ctx, TypeConverter &converter, + DenseSet blockArgsToDetensor) + : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx), + blockArgsToDetensor(blockArgsToDetensor) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.startRootUpdate(op); + Region ®ion = mlir::impl::getFunctionBody(op); + SmallVector conversions; + + for (Block &block : llvm::drop_begin(region, 1)) { + conversions.emplace_back(block.getNumArguments()); + TypeConverter::SignatureConversion &back = conversions.back(); + + for (BlockArgument blockArgument : block.getArguments()) { + int idx = blockArgument.getArgNumber(); + + if (blockArgsToDetensor.count(blockArgument)) + back.addInputs(idx, {getTypeConverter()->convertType( + block.getArgumentTypes()[idx])}); + else + back.addInputs(idx, {block.getArgumentTypes()[idx]}); + } + } - if (failed(rewriter.convertNonEntryRegionTypes( - &mlir::impl::getFunctionBody(op), *typeConverter))) { + if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, + conversions))) { rewriter.cancelRootUpdate(op); return failure(); } @@ -99,6 +126,9 @@ rewriter.finalizeRootUpdate(op); return success(); } + +private: + const DenseSet blockArgsToDetensor; }; class DetensorizeTypeConverter : public TypeConverter { @@ -160,46 +190,309 @@ /// @see LinalgDetensorize in Linalg/Passes.td for more details. 
struct LinalgDetensorize : public LinalgDetensorizeBase { + LinalgDetensorize() = default; + LinalgDetensorize(const LinalgDetensorize &pass) {} + + class CostModel { + public: + virtual ~CostModel() = default; + + /// A cost model algorithm computes the following outputs: + /// + /// - opsToDetensor: the list of linalg ops that should be + /// detensored. + /// + /// - blockArgsToDetensor: since the operands and results of detensored + /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come + /// from a BB argument and a linalg op's output can be passed to successor + /// BBs), we need to maintain the sub-set of arguments that should be + /// detensored (i.e. converted by typeConverter) for each affected BB. + /// + /// Example: + /// + /// For the following snippet: + /// ... + /// ^bb1(%6: tensor, %9: tensor): + /// %7 = linalg.init_tensor [] : tensor + /// %8 = linalg.generic #attrs + /// ins(%6, %6 : tensor, tensor) + /// outs(%7 : tensor) { + /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): + /// %9 = addi %arg0, %arg1 : i32 + /// linalg.yield %9 : i32 + /// } -> tensor + /// %10 = "some.op"(%9) + /// br ^bb2(%8 : tensor) + /// ... + /// + /// if the cost model decides that the linalg.generic op should be + /// detensored, then: + /// - opsToDetensor should be = {linalg.generic{add}}. + /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. + virtual void compute(FuncOp func, DetensorizeTypeConverter typeConverter, + DenseSet &opsToDetensor, + DenseSet &blockArgsToDetensor) = 0; + + /// From the blockArgsToDetensor set computed by a CostModel + /// implementation, this method computes the corresponding branch op + /// detensoring. The result is a map from a branch op to a subset of indices + /// of its operands. The indices specify which of the branch op's operands + /// should be detensored. + /// + /// For the previous example, this method would compute: {bb2 -> {0}}. 
+ static DenseMap> computeBranchOpDetensoring( + const DenseSet &blockArgsToDetensor) { + DenseMap> detensorableBranchOps; + + for (auto blockArgumentElem : blockArgsToDetensor) { + Block *block = blockArgumentElem.getOwner(); + + for (PredecessorIterator pred = block->pred_begin(); + pred != block->pred_end(); ++pred) { + BranchOpInterface terminator = + dyn_cast((*pred)->getTerminator()); + auto blockOperands = + terminator.getSuccessorOperands(pred.getSuccessorIndex()); + + if (!blockOperands || blockOperands->empty()) + continue; + + detensorableBranchOps[terminator].insert( + blockOperands->getBeginOperandIndex() + + blockArgumentElem.getArgNumber()); + } + } + + return detensorableBranchOps; + } + }; + + /// Detensorize linalg ops involved in control-flow within a function. + /// + /// This model starts from CondBranchOps within a function. For each cond_br, + /// the model then walks the use-def chain for the branch's condition + /// backwards in order to understand where the condition's value comes from. + /// If the condition value is (indirectly) computed by a linalg op that can be + /// detensored, the model then continues walking the use-def chain in order to + /// understand where the linalg op's operands come from. This leads to + /// discovering a "detensoring component". A detensoring component is the set + /// of operations + block arguments that are involved in control-flow AND can + /// be detensored. + /// + /// For examples where this model succeeds to discover a detensoring + /// component, see: + /// - test/Dialect/Linalg/detensorize_while.mlir + /// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir. 
+ /// + /// For an example where this model marks control-flow as "non-detensorable", + /// see: + /// - test/Dialect/Linalg/detensorize_while_failure.mlir + class PureControlFlowDetectionModel : public CostModel { + public: + void compute(FuncOp func, DetensorizeTypeConverter typeConverter, + DenseSet &opsToDetensor, + DenseSet &blockArgsToDetensor) override { + SmallVector workList; + + func.walk( + [&](CondBranchOp condBr) { workList.push_back(condBr.condition()); }); + + DenseSet visitedValues; + DenseSet visitedOps; + + while (!workList.empty()) { + Value currentItem = workList.pop_back_val(); + + if (!visitedValues.insert(currentItem).second) + continue; + + // The current item is defined by a block argument. + if (auto bbarg = currentItem.dyn_cast()) { + BlockArgument currentItemBlockArgument = + currentItem.cast(); + Block *ownerBlock = currentItemBlockArgument.getOwner(); + + // Function arguments are not detensored/converted. + if (&*ownerBlock->getParent()->begin() == ownerBlock) + continue; + + // This inner-block argument is involved in control-flow, it should be + // detensored. + blockArgsToDetensor.insert(currentItemBlockArgument); + + for (PredecessorIterator pred = ownerBlock->pred_begin(); + pred != ownerBlock->pred_end(); ++pred) { + BranchOpInterface terminator = + dyn_cast((*pred)->getTerminator()); + + // TODO: For now, we give up if any of the control-flow components + // in a function is not detensorable. Fix that. + if (!terminator) { + opsToDetensor.clear(); + blockArgsToDetensor.clear(); + return; + } + + auto ownerBlockOperands = + terminator.getSuccessorOperands(pred.getSuccessorIndex()); + + if (!ownerBlockOperands || ownerBlockOperands->empty()) + continue; + + // For each predecessor, add the value it passes to that argument to + // workList to find out how it's computed. 
+ workList.push_back( + ownerBlockOperands + .getValue()[currentItemBlockArgument.getArgNumber()]); + } + + continue; + } + + Operation *currentItemDefiningOp = currentItem.getDefiningOp(); + + if (!visitedOps.insert(currentItemDefiningOp).second) + continue; + + // The current item is computed by a GenericOp. + if (auto genericOp = dyn_cast(currentItemDefiningOp)) { + // The op was encountered already, no need to inspect it again. + if (opsToDetensor.count(genericOp)) + continue; + + // TODO: For now, we give up if any of the control-flow components + // in a function is not detensorable. Fix that. + if (!shouldBeDetensored(genericOp, typeConverter)) { + opsToDetensor.clear(); + blockArgsToDetensor.clear(); + return; + } + + opsToDetensor.insert(genericOp); + + for (Value genericOpOperand : genericOp.inputs()) + workList.push_back(genericOpOperand); + + continue; + } + + // The current item is the result of a FromElemntsOp, it will be + // trivially detensored later as part of canonicalization patterns + // applied at the end of detensoring. + // + // Note: No need to check whether the result type of this op is + // detensorable since if it wasn't we wouldn't reach that point in the + // work list. + if (dyn_cast(currentItemDefiningOp)) + continue; + + // The current item is the result of a scalar op, add all its operands + // to the work list. + if (llvm::all_of( + currentItemDefiningOp->getResultTypes(), + [&](Type resultType) { return resultType.isIntOrFloat(); })) + for (Value scalarOpOperand : currentItemDefiningOp->getOperands()) + workList.push_back(scalarOpOperand); + } + } + }; + + /// Detensorize everything that can detensored. 
+ class AggressiveDetensoringModel : public CostModel { + public: + void compute(FuncOp func, DetensorizeTypeConverter typeConverter, + DenseSet &opsToDetensor, + DenseSet &blockArgsToDetensor) override { + func.walk([&](GenericOp genericOp) { + if (shouldBeDetensored(genericOp, typeConverter)) + opsToDetensor.insert(genericOp); + }); + + for (Block &block : llvm::drop_begin(func.getBody(), 1)) + for (BlockArgument blockArgument : block.getArguments()) + blockArgsToDetensor.insert(blockArgument); + } + }; + void runOnFunction() override { - auto *context = &getContext(); + MLIRContext *context = &getContext(); DetensorizeTypeConverter typeConverter; RewritePatternSet patterns(context); ConversionTarget target(*context); + DenseSet opsToDetensor; + DenseMap> detensorableBranchOps; + DenseSet blockArgsToDetensor; + + if (aggressiveMode.getValue()) { + AggressiveDetensoringModel costModel; + costModel.compute(getFunction(), typeConverter, opsToDetensor, + blockArgsToDetensor); + + } else { + PureControlFlowDetectionModel costModel; + costModel.compute(getFunction(), typeConverter, opsToDetensor, + blockArgsToDetensor); + } - target.addDynamicallyLegalOp([&](GenericOp op) { - // If any of the operands or results cannot be detensored (i.e. they are - // all legal according the DetensorizeTypeConverter), the op is considered - // legal and won't be detensored. - return llvm::any_of(op.getShapedOperandTypes(), - [&](ShapedType shapedType) { - return typeConverter.isLegal(shapedType); - }); - }); + detensorableBranchOps = + CostModel::computeBranchOpDetensoring(blockArgsToDetensor); + + target.addDynamicallyLegalOp( + [&](GenericOp op) { return !opsToDetensor.count(op); }); target.addDynamicallyLegalOp([&](FuncOp op) { - // A function is legal if all of its non-entry blocks are legal. We don't - // legalize the entry block (i.e. 
the function's signature) since - // detensoring can't happen along external calling convention boundaries, - // which we conservatively approximate as all function signatures. + // A function is legal if all of its non-entry blocks are legal. We + // don't legalize the entry block (i.e. the function's signature) since + // detensoring can't happen along external calling convention + // boundaries, which we conservatively approximate as all function + // signatures. return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) { - return typeConverter.isLegal(block.getArgumentTypes()); + if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) { + return blockArgument.getOwner() == &block && + !typeConverter.isLegal(blockArgument.getType()); + })) { + return false; + } + return true; }); }); target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return isNotBranchOpInterfaceOrReturnLikeOp(op) || - isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter) || - isLegalForReturnOpTypeConversionPattern( - op, typeConverter, /*returnOpAlwaysLegal*/ true); + if (isNotBranchOpInterfaceOrReturnLikeOp(op) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter, + /*returnOpAlwaysLegal*/ true)) + return true; + + if (auto branchOp = dyn_cast(op)) { + if (!detensorableBranchOps.count(branchOp)) + return true; + + for (auto operandIdx : detensorableBranchOps[branchOp]) + if (!typeConverter.isLegal( + branchOp->getOperand(operandIdx).getType())) + return false; + + return true; + } + + return false; }); - patterns.add(typeConverter, context); - patterns.add(FuncOp::getOperationName(), - context, typeConverter); - // Since non-entry block arguments get detensorized, we also need to update - // the control flow inside the function to reflect the correct types. 
- populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + patterns.insert(typeConverter, context); + patterns.insert(FuncOp::getOperationName(), + context, typeConverter, + blockArgsToDetensor); + // Since non-entry block arguments get detensorized, we also need to + // update the control flow inside the function to reflect the correct + // types. + auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, + int operandIdx) -> bool { + return detensorableBranchOps.count(branchOp) && + detensorableBranchOps[branchOp].count(operandIdx); + }; + + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, + shouldConvertBranchOperand); if (failed(applyFullConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); @@ -210,6 +503,11 @@ std::move(canonPatterns)))) signalPassFailure(); } + + Option aggressiveMode{ + *this, "aggressive-mode", + llvm::cl::desc("Detensorize all ops that qualify for detensoring along " + "with branch operands and basic-block arguments.")}; }; } // namespace diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -52,6 +52,12 @@ using OpInterfaceConversionPattern< BranchOpInterface>::OpInterfaceConversionPattern; + BranchOpInterfaceTypeConversion( + TypeConverter &typeConverter, MLIRContext *ctx, + function_ref shouldConvertBranchOperand) + : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1), + shouldConvertBranchOperand(shouldConvertBranchOperand) {} + LogicalResult matchAndRewrite(BranchOpInterface op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -61,18 +67,23 @@ for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); succIdx < succEnd; ++succIdx) { auto successorOperands = op.getSuccessorOperands(succIdx); - if 
(!successorOperands) + if (!successorOperands || successorOperands->empty()) continue; + for (int idx = successorOperands->getBeginOperandIndex(), eidx = idx + successorOperands->size(); idx < eidx; ++idx) { - newOperands[idx] = operands[idx]; + if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx)) + newOperands[idx] = operands[idx]; } } rewriter.updateRootInPlace( op, [newOperands, op]() { op->setOperands(newOperands); }); return success(); } + +private: + function_ref shouldConvertBranchOperand; }; } // end anonymous namespace @@ -98,9 +109,10 @@ } // end anonymous namespace void mlir::populateBranchOpInterfaceTypeConversionPattern( - RewritePatternSet &patterns, TypeConverter &typeConverter) { - patterns.add(typeConverter, - patterns.getContext()); + RewritePatternSet &patterns, TypeConverter &typeConverter, + function_ref shouldConvertBranchOperand) { + patterns.insert( + typeConverter, patterns.getContext(), shouldConvertBranchOperand); } bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -495,8 +495,17 @@ // to pack the new values. For 1->1 mappings, if there is no materialization // provided, use the argument directly instead. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Value newArg = converter.materializeArgumentConversion( - rewriter, origArg.getLoc(), origArg.getType(), replArgs); + Value newArg; + + // If this is a 1->1 mapping and the types of new and replacement arguments + // match (i.e. it's an identity map), then the argument is mapped to its + // original type. 
+ if (replArgs.size() == 1 && replArgs[0].getType() == origArg.getType()) + newArg = replArgs[0]; + else + newArg = converter.materializeArgumentConversion( + rewriter, origArg.getLoc(), origArg.getType(), replArgs); + if (!newArg) { assert(replArgs.size() == 1 && "couldn't materialize the result of 1->N conversion"); @@ -754,8 +763,9 @@ TypeConverter::SignatureConversion *entryConversion); /// Convert the types of non-entry block arguments within the given region. - LogicalResult convertNonEntryRegionTypes(Region *region, - TypeConverter &converter); + LogicalResult convertNonEntryRegionTypes( + Region *region, TypeConverter &converter, + ArrayRef blockConversions = {}); //===--------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1173,15 +1183,30 @@ } LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( - Region *region, TypeConverter &converter) { + Region *region, TypeConverter &converter, + ArrayRef blockConversions) { argConverter.setConverter(region, &converter); if (region->empty()) return success(); // Convert the arguments of each block within the region. - for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) - if (failed(convertBlockSignature(&block, converter))) + int blockIdx = 0; + assert((blockConversions.empty() || + blockConversions.size() == region->getBlocks().size() - 1) && + "expected either to provide no SignatureConversions at all or to " + "provide a SignatureConversion for each non-entry block"); + + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { + TypeConverter::SignatureConversion *blockConversion = + blockConversions.empty() + ? 
nullptr + : const_cast( + &blockConversions[blockIdx++]); + + if (failed(convertBlockSignature(&block, converter, blockConversion))) return failure(); + } return success(); } @@ -1351,8 +1376,9 @@ } LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( - Region *region, TypeConverter &converter) { - return impl->convertNonEntryRegionTypes(region, converter); + Region *region, TypeConverter &converter, + ArrayRef blockConversions) { + return impl->convertNonEntryRegionTypes(region, converter, blockConversions); } void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s + +#map0 = affine_map<() -> ()> + +#attrs = { + indexing_maps = [#map0, #map0, #map0], + iterator_types = [] +} + +func @main() -> (tensor) attributes {} { + %c0 = constant 0 : i32 + %0 = tensor.from_elements %c0 : tensor<1xi32> + %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor + %c10 = constant 10 : i32 + %1 = tensor.from_elements %c10 : tensor<1xi32> + %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + br ^bb1(%reshaped0 : tensor) + +^bb1(%2: tensor): // 2 preds: ^bb0, ^bb2 + %3 = linalg.init_tensor [] : tensor + %4 = linalg.generic #attrs + ins(%2, %reshaped1 : tensor, tensor) + outs(%3 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors + %8 = cmpi slt, %arg0, %arg1 : i32 + linalg.yield %8 : i1 + } -> tensor + %5 = tensor.extract %4[] : tensor + cond_br %5, ^bb2(%2 : tensor), ^bb3(%2 : tensor) + +^bb2(%6: tensor): // pred: ^bb1 + %7 = linalg.init_tensor [] : tensor + %8 = linalg.generic #attrs + ins(%6, %6 : tensor, tensor) + outs(%7 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors + %9 
= addi %arg0, %arg1 : i32 + linalg.yield %9 : i32 + } -> tensor + br ^bb3(%8 : tensor) + +^bb3(%10: tensor): // pred: ^bb1 + return %10 : tensor +} + +// CHECK-LABEL: func @main() +// CHECK-NEXT: constant 0 +// CHECK-NEXT: constant 10 +// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}}: i32) +// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32): +// CHECK-NEXT: tensor.from_elements %{{.*}} +// CHECK-NEXT: linalg.tensor_reshape %{{.*}} +// CHECK-NEXT: cmpi slt, %{{.*}}, %{{.*}} +// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : tensor), ^bb3(%{{.*}} : tensor) +// CHECK-NEXT: ^[[bb2]](%{{.*}}: tensor) +// CHECK-NEXT: linalg.init_tensor +// CHECK-NEXT: linalg.generic +// CHECK-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32) +// CHECK-NEXT: addi %{{.*}}, %{{.*}} +// CHECK-NEXT: linalg.yield %{{.*}} +// CHECK-NEXT: } -> tensor +// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : tensor) +// CHECK-NEXT: ^[[bb3]](%{{.*}}: tensor) +// CHECK-NEXT: return %{{.*}} +// CHECK-NEXT: } diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL +// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF + + +#map0 = affine_map<() -> ()> + +#attrs = { + indexing_maps = [#map0, #map0, #map0], + iterator_types = [] +} + +func @main(%farg0 : tensor) -> (tensor) attributes {} { + %c10 = constant 10 : i32 + %1 = tensor.from_elements %c10 : tensor<1xi32> + %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor + %3 = linalg.init_tensor [] : tensor + %4 = linalg.generic #attrs + ins(%farg0, %reshaped1 : tensor, tensor) + outs(%3 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): + %8 = cmpi slt, %arg0, %arg1 : i32 + linalg.yield %8 : i1 + } -> tensor + return %4 : tensor +} + + +// DET-ALL-LABEL: func @main(%{{.*}}: tensor) +// 
DET-ALL-NEXT: constant 10 +// DET-ALL-NEXT: tensor.extract %{{.*}}[] +// DET-ALL-NEXT: cmpi slt, %{{.*}}, %{{.*}} +// DET-ALL-NEXT: tensor.from_elements %{{.*}} +// DET-ALL-NEXT: linalg.tensor_reshape %{{.*}} +// DET-ALL-NEXT: return %{{.*}} : tensor +// DET-ALL-NEXT: } + +// DET-CF-LABEL: func @main(%{{.*}}: tensor) +// DET-CF-NEXT: constant 10 : i32 +// DET-CF-NEXT: tensor.from_elements %{{.*}} +// DET-CF-NEXT: linalg.tensor_reshape %{{.*}} +// DET-CF-NEXT: linalg.init_tensor [] : tensor +// DET-CF-NEXT: linalg.generic +// DET-CF-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1) +// DET-CF-NEXT: cmpi slt, %{{.*}}, %{{.*}} +// DET-CF-NEXT: linalg.yield %{{.*}} +// DET-CF-NEXT: } -> tensor +// DET-CF-NEXT: return %{{.*}} +// DET-CF-NEXT: } diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL +// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF + +#map0 = affine_map<() -> ()> + +#attrs = { + indexing_maps = [#map0, #map0, #map0], + iterator_types = [] +} + +func @main(%farg0: tensor, %farg1: tensor) -> tensor attributes {} { + br ^bb1(%farg0 : tensor) + +^bb1(%0: tensor): // 2 preds: ^bb0, ^bb2 + %1 = linalg.init_tensor [] : tensor + %2 = linalg.generic #attrs + ins(%0, %farg1 : tensor, tensor) + outs(%1 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors + %8 = cmpi slt, %arg0, %arg1 : i32 + linalg.yield %8 : i1 + } -> tensor + %3 = tensor.extract %2[] : tensor + cond_br %3, ^bb2(%0 : tensor), ^bb3(%0 : tensor) + +^bb2(%4: tensor): // pred: ^bb1 + %5 = linalg.init_tensor [] : tensor + %6 = linalg.generic #attrs + ins(%4, %4 : tensor, tensor) + outs(%5 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors + %8 = addi %arg0, %arg1 : i32 + 
linalg.yield %8 : i32 + } -> tensor + br ^bb1(%6 : tensor) + +^bb3(%7: tensor): // pred: ^bb1 + return %7 : tensor +} + +// Test aggresively detensoring all detensorable ops. +// +// DET-ALL-LABEL: func @main +// DET-ALL-SAME: (%{{.*}}: tensor, %{{.*}}: tensor) +// DET-ALL: tensor.extract {{.*}} +// DET-ALL: br ^[[bb1:.*]](%{{.*}} : i32) +// DET-ALL: ^[[bb1]](%{{.*}}: i32) +// DET-ALL: cmpi slt, {{.*}} +// DET-ALL: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) +// DET-ALL: ^[[bb2]](%{{.*}}: i32) +// DET-ALL: addi {{.*}} +// DET-ALL: br ^[[bb1]](%{{.*}} : i32) +// DET-ALL: ^[[bb3]](%{{.*}}: i32) +// DET-ALL: tensor.from_elements {{.*}} +// DET-ALL: linalg.tensor_reshape {{.*}} +// DET-ALL: return %{{.*}} : tensor + +// Test detensoring only ops involed in control-flow. +// +// DET-CF-LABEL: func @main +// DET-CF-SAME: (%{{.*}}: tensor, %{{.*}}: tensor) +// DET-CF: tensor.extract {{.*}} +// DET-CF: br ^[[bb1:.*]](%{{.*}} : i32) +// DET-CF: ^[[bb1]](%{{.*}}: i32) +// DET-CF-DAG tensor.from_elements {{.*}} +// DET-CF-DAG: linalg.tensor_reshape {{.*}} +// DET-CF-DAG: cmpi slt, {{.*}} +// DET-CF: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : tensor) +// DET-CF: ^[[bb2]](%{{.*}}: i32) +// DET-CF: addi {{.*}} +// DET-CF: br ^[[bb1]](%{{.*}} : i32) +// DET-CF: ^[[bb3]](%{{.*}}: tensor) +// DET-CF: return %{{.*}} : tensor diff --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir @@ -0,0 +1,111 @@ +// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL +// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF + +#map0 = affine_map<() -> ()> +#map1 = affine_map<(i) -> ()> +#map2 = affine_map<(i) -> (i)> + +#attrs = { + indexing_maps = [#map0, #map0, #map0], + iterator_types = [] +} + +#sum_reduction_attrs = { + 
indexing_maps = [#map2, #map1], + iterator_types = ["reduction"] +} + + +#broadcast_attrs = { + indexing_maps = [#map1, #map2], + iterator_types = ["parallel"] +} + +func @main(%farg0: tensor<10xi32>, %farg1: tensor) -> tensor attributes {} { + br ^bb1(%farg0 : tensor<10xi32>) + +^bb1(%0: tensor<10xi32>): // 2 preds: ^bb0, ^bb2 + %1 = linalg.init_tensor [] : tensor + %2 = linalg.generic #sum_reduction_attrs + ins(%0: tensor<10xi32>) + outs(%1: tensor) { + ^bb(%a: i32, %x: i32): + %b = addi %x, %a : i32 + linalg.yield %b : i32 + } -> tensor + + %3 = linalg.init_tensor [] : tensor + %4 = linalg.generic #attrs + ins(%2, %farg1 : tensor, tensor) + outs(%3 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors + %8 = cmpi slt, %arg0, %arg1 : i32 + linalg.yield %8 : i1 + } -> tensor + %5 = tensor.extract %4[] : tensor + cond_br %5, ^bb2(%2 : tensor), ^bb3(%2 : tensor) + +^bb2(%6: tensor): // pred: ^bb1 + %7 = linalg.init_tensor [10] : tensor<10xi32> + %9 = linalg.generic #broadcast_attrs + ins(%6: tensor) + outs(%7: tensor<10xi32>) { + ^bb(%a: i32, %b: i32) : + linalg.yield %a : i32 + } -> tensor<10xi32> + + br ^bb1(%9 : tensor<10xi32>) + +^bb3(%10: tensor): // pred: ^bb1 + return %10 : tensor +} + +// Test aggresively detensoring all detensorable ops. 
+// +// DET-ALL-LABEL: func @main +// DET-ALL-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor) +// DET-ALL: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>) +// DET-ALL: ^[[bb1]](%{{.*}}: tensor<10xi32>) +// DET-ALL: linalg.init_tensor [] : tensor +// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor) { +// DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32): // no predecessors +// DET-ALL: %{{.*}} = addi %{{.*}}, %{{.*}} +// DET-ALL: linalg.yield %{{.*}} : i32 +// DET-ALL: } -> tensor +// DET-ALL: tensor.extract %{{.*}}[] : tensor +// DET-ALL: tensor.extract %{{.*}}[] : tensor +// DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32 +// DET-ALL: tensor.extract %{{.*}}[] : tensor +// DET-ALL: tensor.extract %{{.*}}[] : tensor +// DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) +// DET-ALL: ^[[bb2]](%{{.*}}: i32) +// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32> +// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// DET-ALL: linalg.init_tensor [10] : tensor<10xi32> +// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor) outs(%{{.*}} : tensor<10xi32>) { +// DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32): +// DET-ALL: linalg.yield %{{.*}} : i32 +// DET-ALL: } -> tensor<10xi32> +// DET-ALL: br ^[[bb1]](%{{.*}} : tensor<10xi32>) +// DET-ALL: ^[[bb3]](%{{.*}}: i32) +// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32> +// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor +// DET-ALL: return %{{.*}} : tensor +// DET-ALL: } + +// Try to detensor pure control-flow. However, that fails since the potential +// detensorable component contains some ops that cannot be detensored. 
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-CF: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-CF: ^bb1(%{{.*}}: tensor<10xi32>)
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
+// DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-CF: br ^bb1(%{{.*}} : tensor<10xi32>)
+// DET-CF: ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF: return %{{.*}} : tensor<i32>
+// DET-CF: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main() -> () attributes {} {
+  %c0 = constant 0 : i32
+  %0 = tensor.from_elements %c0 : tensor<1xi32>
+  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %c10 = constant 10 : i32
+  %1 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3
+
+^bb2(%6: tensor<i32>):  // pred: ^bb1
+  %7 = linalg.init_tensor [] : tensor<i32>
+  %8 = 
linalg.generic #attrs
+    ins(%6, %6 : tensor<i32>, tensor<i32>)
+    outs(%7 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %9 = addi %arg0, %arg1 : i32
+      linalg.yield %9 : i32
+  } -> tensor<i32>
+  br ^bb1(%8 : tensor<i32>)
+
+^bb3:  // pred: ^bb1
+  return
+}
+
+// CHECK-LABEL: func @main
+// CHECK-NEXT: constant 0 : i32
+// CHECK-NEXT: constant 10
+// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32)
+// CHECK-NEXT: %{{.*}} = cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT: %{{.*}} = addi %{{.*}}, %{{.*}}
+// CHECK-NEXT: br ^[[bb1]](%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
--- a/mlir/test/Dialect/Linalg/detensorized_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize=aggressive-mode | FileCheck %s
 
 #map = affine_map<() -> ()>
diff --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir
deleted file mode 100644
--- a/mlir/test/Dialect/Linalg/detensorized_while.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s
-
-#map0 = affine_map<() -> ()>
-
-#attrs = {
-  indexing_maps = [#map0, #map0, #map0],
-  iterator_types = []
-}
-
-func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
-  br ^bb1(%farg0 : tensor<i32>)
-
-^bb1(%0: tensor<i32>):  // 2 preds: ^bb0, ^bb2
-  %1 = linalg.init_tensor [] : tensor<i1>
-  %2 = linalg.generic #attrs
-    ins(%0, %farg1 : tensor<i32>, tensor<i32>)
-    outs(%1 : tensor<i1>) {
-    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
-      %8 = cmpi slt, %arg0, %arg1 : i32
-      linalg.yield %8 : i1
-  } -> tensor<i1>
-
%3 = tensor.extract %2[] : tensor<i1>
-  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
-
-^bb2(%4: tensor<i32>):  // pred: ^bb1
-  %5 = linalg.init_tensor [] : tensor<i32>
-  %6 = linalg.generic #attrs
-    ins(%4, %4 : tensor<i32>, tensor<i32>)
-    outs(%5 : tensor<i32>) {
-    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
-      %8 = addi %arg0, %arg1 : i32
-      linalg.yield %8 : i32
-  } -> tensor<i32>
-  br ^bb1(%6 : tensor<i32>)
-
-^bb3(%7: tensor<i32>):  // pred: ^bb1
-  return %7 : tensor<i32>
-}
-
-// CHECK-LABEL: func @main
-// CHECK-SAME: (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
-// CHECK: tensor.extract {{.*}}
-// CHECK: br ^[[bb1:.*]](%{{.*}} : i32)
-// CHECK: ^[[bb1]](%{{.*}}: i32)
-// CHECK: cmpi slt, {{.*}}
-// CHECK: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK: ^[[bb2]](%{{.*}}: i32)
-// CHECK: addi {{.*}}
-// CHECK: br ^[[bb1]](%{{.*}} : i32)
-// CHECK: ^[[bb3]](%{{.*}}: i32)
-// CHECK: tensor.from_elements {{.*}}
-// CHECK: linalg.tensor_reshape {{.*}}
-// CHECK: return %{{.*}} : tensor<i32>