diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -151,12 +151,16 @@ linalg-on-tensor op is checked to see whether *all* its operands can be detensored. If so, those operands are converted to their primitive counterparts and the linalg op is replaced by an equivalent op that takes - those new primitive values as operands. Therefore, the detensoring process - can be divided into 2 main logical phases: + those new primitive values as operands. Therefore, detensoring an op can be + divided into 2 main logical phases: 1. Detect/match an op that can be detensored. 2. Detensor the operands of the op and replace it with a primitive equivalent. + + In addition to detensoring individual ops, this pass detensors internal + control flow inside a function. All blocks except for the entry block are + detensored by converting their arguments whenever possible. }]; } diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h @@ -16,7 +16,8 @@ namespace mlir { // Forward declarations. class MLIRContext; +class Operation; class OwningRewritePatternList; class TypeConverter; @@ -26,13 +27,34 @@ MLIRContext *ctx, TypeConverter &converter); -/// Add a pattern to the given pattern list to rewrite branch operations and -/// `return` to use operands that have been legalized by the conversion -/// framework. This can only be done if the branch operation implements the -/// BranchOpInterface. Only needed for partial conversions.
-void populateBranchOpInterfaceAndReturnOpTypeConversionPattern( +/// Add a pattern to the given pattern list to rewrite branch operations to use +/// operands that have been legalized by the conversion framework. This can only +/// be done if the branch operation implements the BranchOpInterface. Only +/// needed for partial conversions. +void populateBranchOpInterfaceTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter); + +/// Return true if op is a BranchOpInterface op whose operands are all legal +/// according to converter. +bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op, + TypeConverter &converter); + +/// Add a pattern to the given pattern list to rewrite `return` ops to use +/// operands that have been legalized by the conversion framework. +void populateReturnOpTypeConversionPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx, + TypeConverter &converter); + +/// If op is a `return` and returnOpAlwaysLegal is false, return whether it is +/// legal according to converter. Otherwise, return true for ReturnLike ops +/// (including `return`) and false for all other ops. +bool isLegalForReturnOpTypeConversionPattern(Operation *op, + TypeConverter &converter, + bool returnOpAlwaysLegal = false); + +/// Return true if op can't be a terminator or isn't the last op in its block.
+bool isNotBranchOrReturnLikeOp(Operation *op); } // end namespace mlir #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_ diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -474,6 +474,12 @@ Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion = nullptr); + /// Convert the types of block arguments within the given region except for + /// the entry block. This replaces each non-entry block with a new block + /// containing the updated signature. + LogicalResult convertNonEntryRegionTypes(Region *region, + TypeConverter &converter); + /// Replace all the uses of the block argument `from` with value `to`. void replaceUsesOfBlockArgument(BlockArgument from, Value to); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -21,6 +21,20 @@ using namespace mlir; using namespace mlir::linalg; +static Value sourceMaterializationCallback(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + // A detensored value is converted back by creating a new tensor from its + // element(s). + auto createNewTensorOp = builder.create( + loc, inputs[0].getType(), inputs[0]); + + // FromElementsOp results in a tensor<1xdtype>, so reshape that into the + // requested `type` instead. + return builder.create( + loc, type, createNewTensorOp, ArrayRef{}); +} + namespace { /// Defines the criteria a TensorType must follow in order to be considered /// "detensorable". @@ -64,6 +78,29 @@ } }; +/// A conversion pattern for detensoring internal (non-entry) blocks within a +/// function.
+struct FunctionNonEntryBlockConversion : public ConversionPattern { + FunctionNonEntryBlockConversion(StringRef functionLikeOpName, + MLIRContext *ctx, TypeConverter &converter) + : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.startRootUpdate(op); + + if (failed(rewriter.convertNonEntryRegionTypes( + &mlir::impl::getFunctionBody(op), *typeConverter))) { + rewriter.cancelRootUpdate(op); + return failure(); + } + + rewriter.finalizeRootUpdate(op); + return success(); + } +}; + class DetensorizeTypeConverter : public TypeConverter { public: DetensorizeTypeConverter() { @@ -84,18 +121,8 @@ return builder.create(loc, inputs[0], ValueRange{}); }); - // A detensored value is converted back by creating a new tensor from its - // element(s). - addSourceMaterialization([](OpBuilder &builder, Type type, - ValueRange inputs, Location loc) -> Value { - auto createNewTensorOp = builder.create( - loc, inputs[0].getType(), inputs[0]); - - // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to - // a tensor instead. - return builder.create( - loc, type, createNewTensorOp, ArrayRef{}); - }); + addSourceMaterialization(sourceMaterializationCallback); + addArgumentMaterialization(sourceMaterializationCallback); } }; @@ -139,22 +166,43 @@ OwningRewritePatternList patterns; ConversionTarget target(*context); - target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; }); - target.addLegalDialect(); target.addDynamicallyLegalOp([&](GenericOp op) { - // If any of the operands or results cannot be detensored, the op is - // considered legal and won't be detensored. - return llvm::any_of( - op.getShapedOperandTypes(), [](ShapedType shapedType) { - assert(shapedType.isa()); - return !canBeDetensored(shapedType.cast()); - }); + // If any of the operands or results cannot be detensored (i.e. at least
+ // one of them is legal according to the DetensorizeTypeConverter), the op + // is considered legal and won't be detensored. + return llvm::any_of(op.getShapedOperandTypes(), + [&](ShapedType shapedType) { + return typeConverter.isLegal(shapedType); + }); }); - patterns.insert(typeConverter, context); + target.addDynamicallyLegalOp([&](FuncOp op) { + // A function is legal if all of its non-entry blocks are legal. We don't + // legalize the entry block (i.e. the function's signature) since + // detensoring can't happen along external calling convention boundaries, + // which we conservatively approximate as all function signatures. + return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) { + return typeConverter.isLegal(block.getArgumentTypes()); + }); + }); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return isNotBranchOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter) || + isLegalForReturnOpTypeConversionPattern( + op, typeConverter, /*returnOpAlwaysLegal=*/true); + }); - if (failed( - applyPartialConversion(getFunction(), target, std::move(patterns)))) + patterns.insert(typeConverter, context); + patterns.insert(FuncOp::getOperationName(), + context, typeConverter); + // Since non-entry block arguments get detensorized, we also need to update + // the control flow inside the function to reflect the correct types. + populateBranchOpInterfaceTypeConversionPattern(patterns, context, + typeConverter); + + if (failed(applyFullConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); OwningRewritePatternList canonPatterns; @@ -162,8 +210,6 @@ if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(canonPatterns)))) signalPassFailure(); - - // TODO Properly handle control flow within function boundaries.
+ } }; } // namespace diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -40,39 +40,17 @@ target.addDynamicallyLegalOp( [&](CallOp op) { return typeConverter.isLegal(op); }); - populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context, - typeConverter); + populateBranchOpInterfaceTypeConversionPattern(patterns, context, + typeConverter); + populateReturnOpTypeConversionPattern(patterns, context, typeConverter); target.addLegalOp(); - target.addDynamicallyLegalOp( - [&](ReturnOp op) { return typeConverter.isLegal(op); }); - // Mark terminators as legal if they have the ReturnLike trait or - // implement the BranchOpInterface and have valid types. If they do not - // implement the trait or interface, mark them as illegal no matter what. + target.markUnknownOpDynamicallyLegal([&](Operation *op) { - // If it is not a terminator, ignore it. - if (!op->mightHaveTrait()) - return true; - // If it is not the last operation in the block, also ignore it. We do - // this to handle unknown operations, as well. - Block *block = op->getBlock(); - if (!block || &block->back() != op) - return true; - // ReturnLike operations have to be legalized with their parent. For - // return this is handled, for other ops they remain as is. - if (op->hasTrait()) - return true; - // All successor operands of branch like operations must be rewritten.
- if (auto branchOp = dyn_cast(op)) { - for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { - auto successorOperands = branchOp.getSuccessorOperands(p); - if (successorOperands.hasValue() && - !typeConverter.isLegal(successorOperands.getValue().getTypes())) - return false; - } - return true; - } - return false; + return isNotBranchOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter); }); if (failed(applyFullConversion(module, target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -102,9 +102,61 @@ }; } // end anonymous namespace -void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern( +void mlir::populateBranchOpInterfaceTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &typeConverter) { - patterns.insert( - typeConverter, ctx); + patterns.insert(typeConverter, ctx); +} + +bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( + Operation *op, TypeConverter &converter) { + // A branch is legal only if all of its successor operands have legal types.
+ if (auto branchOp = dyn_cast(op)) { + for (int p = 0, e = op->getNumSuccessors(); p < e; ++p) { + auto successorOperands = branchOp.getSuccessorOperands(p); + if (successorOperands.hasValue() && + !converter.isLegal(successorOperands.getValue().getTypes())) + return false; + } + return true; + } + + return false; +} + +void mlir::populateReturnOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &typeConverter) { + patterns.insert(typeConverter, ctx); +} + +bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op, + TypeConverter &converter, + bool returnOpAlwaysLegal) { + // If this is a `return` and the user pass wants to convert/transform across + // function boundaries, then `converter` is invoked to check whether the + // `return` op is legal. + if (dyn_cast(op) && !returnOpAlwaysLegal) + return converter.isLegal(op); + + // ReturnLike operations have to be legalized with their parent. For + // return this is handled, for other ops they remain as is. + if (op->hasTrait()) + return true; + + return false; +} + +bool mlir::isNotBranchOrReturnLikeOp(Operation *op) { + // If it is not a terminator, ignore it. + if (!op->mightHaveTrait()) + return true; + + // If it is not the last operation in the block, also ignore it. We do + // this to handle unknown operations, as well. + Block *block = op->getBlock(); + if (!block || &block->back() != op) + return true; + + return false; } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -749,6 +749,10 @@ convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion); + /// Convert the types of non-entry block arguments within the given region.
+ LogicalResult convertNonEntryRegionTypes(Region *region, + TypeConverter &converter); + //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -1150,13 +1154,25 @@ if (region->empty()) return nullptr; - // Convert the arguments of each block within the region. + if (failed(convertNonEntryRegionTypes(region, converter))) + return failure(); + FailureOr newEntry = convertBlockSignature(&region->front(), converter, entryConversion); + return newEntry; +} + +LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( + Region *region, TypeConverter &converter) { + argConverter.setConverter(region, &converter); + if (region->empty()) + return success(); + + // Convert the arguments of each non-entry block within the region. for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) if (failed(convertBlockSignature(&block, converter))) return failure(); - return newEntry; + return success(); } //===----------------------------------------------------------------------===// @@ -1323,6 +1339,11 @@ return impl->convertRegionTypes(region, converter, entryConversion); } +LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( + Region *region, TypeConverter &converter) { + return impl->convertNonEntryRegionTypes(region, converter); +} + void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value to) { LLVM_DEBUG({ diff --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/detensorized_while.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s + +#map0 = affine_map<() -> ()> + +#attrs = { + indexing_maps = [#map0, #map0, #map0], + iterator_types = [] +} + +func @main(%farg0: tensor, %farg1: tensor) -> tensor attributes {} { + br ^bb1(%farg0 :
tensor) + +^bb1(%0: tensor): // 2 preds: ^bb0, ^bb2 + %1 = linalg.init_tensor [] : tensor + %2 = linalg.generic #attrs + ins(%0, %farg1 : tensor, tensor) + outs(%1 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors + %8 = cmpi slt, %arg0, %arg1 : i32 + linalg.yield %8 : i1 + } -> tensor + %3 = tensor.extract %2[] : tensor + cond_br %3, ^bb2(%0 : tensor), ^bb3(%0 : tensor) + +^bb2(%4: tensor): // pred: ^bb1 + %5 = linalg.init_tensor [] : tensor + %6 = linalg.generic #attrs + ins(%4, %4 : tensor, tensor) + outs(%5 : tensor) { + ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors + %8 = addi %arg0, %arg1 : i32 + linalg.yield %8 : i32 + } -> tensor + br ^bb1(%6 : tensor) + +^bb3(%7: tensor): // pred: ^bb1 + return %7 : tensor +} + +// CHECK-LABEL: func @main +// CHECK-SAME: (%{{.*}}: tensor, %{{.*}}: tensor) +// CHECK: tensor.extract {{.*}} +// CHECK: br ^[[bb1:.*]](%{{.*}} : i32) +// CHECK: ^[[bb1]](%{{.*}}: i32) +// CHECK: cmpi slt, {{.*}} +// CHECK: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) +// CHECK: ^[[bb2]](%{{.*}}: i32) +// CHECK: addi {{.*}} +// CHECK: br ^[[bb1]](%{{.*}} : i32) +// CHECK: ^[[bb3]](%{{.*}}: i32) +// CHECK: tensor.from_elements {{.*}} +// CHECK: linalg.tensor_reshape {{.*}} +// CHECK: return %{{.*}} : tensor