diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -62,6 +62,7 @@
 /// Create a pass to convert Linalg operations to equivalent operations that
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
+std::unique_ptr<Pass> createFuncDetensorizePass();
 
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -160,4 +160,9 @@
 }];
 }
 
+def FuncDetensorize : FunctionPass<"func-detensorize"> {
+  let summary = "Detensorize block arguments and branch operands within a function";
+  let constructor = "mlir::createFuncDetensorizePass()";
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -30,9 +30,12 @@
 /// `return` to use operands that have been legalized by the conversion
 /// framework. This can only be done if the branch operation implements the
 /// BranchOpInterface. Only needed for partial conversions.
-void populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
+void populateBranchOpInterfaceTypeConversionPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx,
     TypeConverter &converter);
+void populateReturnOpTypeConversionPattern(OwningRewritePatternList &patterns,
+                                           MLIRContext *ctx,
+                                           TypeConverter &converter);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -88,6 +88,26 @@
   // element(s).
   addSourceMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
                               Location loc) -> Value {
+    if (inputs.empty())
+      return builder.create<ConstantOp>(
+          loc, builder.getI32Type(),
+          builder.getZeroAttr(builder.getI32Type()));
+    for (auto input : inputs) {
+      input.dump();
+    }
+    auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
+        loc, inputs[0].getType(), inputs[0]);
+
+    // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
+    // a tensor instead.
+    return builder
+        .create<linalg::TensorReshapeOp>(loc, type, createNewTensorOp,
+                                         ArrayRef<ReassociationExprs>{})
+        .getResult();
+  });
+
+  addArgumentMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
     auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
         loc, inputs[0].getType(), inputs[0]);
 
@@ -144,11 +164,10 @@
     target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
       // If any of the operands or results cannot be detensored, the op is
      // considered legal and won't be detensored.
-      return llvm::any_of(
-          op.getShapedOperandTypes(), [](ShapedType shapedType) {
-            assert(shapedType.isa<TensorType>());
-            return !canBeDetensored(shapedType.cast<TensorType>());
-          });
+      return llvm::any_of(op.getShapedOperandTypes(),
+                          [&](ShapedType shapedType) {
+                            return typeConverter.isLegal(shapedType);
+                          });
     });
 
     patterns.insert<DetensorizeGenericOp>(typeConverter, context);
@@ -162,8 +181,99 @@
     if (failed(applyPatternsAndFoldGreedily(getFunction(),
                                             std::move(canonPatterns))))
       signalPassFailure();
+  }
+};
 
-    // TODO Properly handle control flow within function boundaries.
+struct FunctionInternalBlockConversion : public ConversionPattern {
+  FunctionInternalBlockConversion(StringRef functionLikeOpName,
+                                  MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
+
+  /// Hook to implement combined matching and rewriting for FunctionLike ops.
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FunctionType type = mlir::impl::getFunctionType(op);
+    rewriter.startRootUpdate(op);
+
+    // Convert the original function types.
+    TypeConverter::SignatureConversion result(type.getNumInputs());
+    result.addInputs(type.getInputs());
+
+    SmallVector<Type, 1> newResults;
+    if (failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
+                                           *typeConverter, &result))) {
+      rewriter.cancelRootUpdate(op);
+      return failure();
+    }
+
+    rewriter.finalizeRootUpdate(op);
+    return success();
+  }
+};
+
+struct FuncDetensorize : public FuncDetensorizeBase<FuncDetensorize> {
+  void runOnFunction() override {
+    auto func = getFunction();
+    auto *context = &getContext();
+
+    DetensorizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    patterns.insert<FunctionInternalBlockConversion>(
+        FuncOp::getOperationName(), context, typeConverter);
+    populateBranchOpInterfaceTypeConversionPattern(patterns, context,
+                                                   typeConverter);
+
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return std::all_of(std::next(op.getBody().begin()), op.getBody().end(),
+                         [&](Block &block) {
+                           return typeConverter.isLegal(
+                               block.getArgumentTypes());
+                         });
+    });
+
+    target.addLegalOp();
+
+    // Mark terminators as legal if they have the ReturnLike trait or
+    // implement the BranchOpInterface and have valid types. If they do not
+    // implement the trait or interface, mark them as illegal no matter what.
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      // return true;
+      // If it is not a terminator, ignore it.
+      if (!op->mightHaveTrait<OpTrait::IsTerminator>())
+        return true;
+      // If it is not the last operation in the block, also ignore it. We do
+      // this to handle unknown operations, as well.
+      Block *block = op->getBlock();
+      if (!block || &block->back() != op)
+        return true;
+      // ReturnLike operations have to be legalized with their parent. For
+      // return this is handled, for other ops they remain as is.
+      if (op->hasTrait<OpTrait::ReturnLike>())
+        return true;
+      // All successor operands of branch like operations must be rewritten.
+      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+        for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
+          auto successorOperands = branchOp.getSuccessorOperands(p);
+          if (successorOperands.hasValue() &&
+              !typeConverter.isLegal(successorOperands.getValue().getTypes()))
+            return false;
+        }
+        return true;
+      }
+      return false;
+    });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+
+    OwningRewritePatternList canonPatterns;
+    canonPatterns.insert<ExtractFromReshapeFromElements>(context);
+
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(canonPatterns))))
+      signalPassFailure();
   }
 };
 } // namespace
 
@@ -171,3 +281,7 @@
 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
   return std::make_unique<LinalgDetensorize>();
 }
+
+std::unique_ptr<Pass> mlir::createFuncDetensorizePass() {
+  return std::make_unique<FuncDetensorize>();
+}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -40,8 +40,9 @@
     target.addDynamicallyLegalOp<CallOp>(
         [&](CallOp op) { return typeConverter.isLegal(op); });
 
-    populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context,
-                                                              typeConverter);
+    populateBranchOpInterfaceTypeConversionPattern(patterns, context,
+                                                   typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, context, typeConverter);
     target.addLegalOp();
 
     target.addDynamicallyLegalOp(
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -102,9 +102,14 @@
 };
 } // end anonymous namespace
 
-void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
+void mlir::populateBranchOpInterfaceTypeConversionPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx,
     TypeConverter &typeConverter) {
-  patterns.insert<BranchOpInterfaceTypeConversion, ReturnOpTypeConversion>(
-      typeConverter, ctx);
+  patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter, ctx);
+}
+
+void mlir::populateReturnOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &typeConverter) {
+  patterns.insert<ReturnOpTypeConversion>(typeConverter, ctx);
 }
diff --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorized_while.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -linalg-detensorize -func-detensorize -canonicalize | FileCheck %s
+
+func @main(%cst: tensor<i32>, %cst_0: tensor<i32>) -> tensor<i32> attributes {iree.module.export} {
+//  %cst = constant dense<1> : tensor<i32>
+//  %cst_0 = constant dense<3> : tensor<i32>
+  br ^bb1(%cst : tensor<i32>)
+^bb1(%0: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i1>
+  %2 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%0, %cst_0 : tensor<i32>, tensor<i32>) outs(%1 : tensor<i1>) {
+  ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+    %8 = cmpi slt, %arg0, %arg1 : i32
+    linalg.yield %8 : i1
+  } -> tensor<i1>
+  %3 = tensor.extract %2[] : tensor<i1>
+  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
+^bb2(%4: tensor<i32>):  // pred: ^bb1
+  %5 = linalg.init_tensor [] : tensor<i32>
+  %6 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%4, %4 : tensor<i32>, tensor<i32>) outs(%5 : tensor<i32>) {
+  ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+    %8 = addi %arg0, %arg1 : i32
+    linalg.yield %8 : i32
+  } -> tensor<i32>
+  br ^bb1(%6 : tensor<i32>)
+^bb3(%7: tensor<i32>):  // pred: ^bb1
+  return %7 : tensor<i32>
+}
+// CHECK-LABEL: func @main
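
Note (illustration only, not part of the patch): the RUN line above drives the new pipeline through mlir-opt. The C++ sketch below shows how the same pipeline might be scheduled programmatically; `runDetensorizePipeline` is a hypothetical helper name, and the snippet assumes the pass declarations this patch adds to mlir/Dialect/Linalg/Passes.h plus the stock canonicalizer pass. Since both passes are FunctionPasses, they are nested under FuncOp.

    #include "mlir/Dialect/Linalg/Passes.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    // Mirror "-linalg-detensorize -func-detensorize -canonicalize" from the
    // RUN line of detensorized_while.mlir on an already-parsed module.
    static mlir::LogicalResult runDetensorizePipeline(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      pm.addNestedPass<mlir::FuncOp>(mlir::createLinalgDetensorizePass());
      pm.addNestedPass<mlir::FuncOp>(mlir::createFuncDetensorizePass());
      pm.addPass(mlir::createCanonicalizerPass());
      return pm.run(module);
    }

With the full pipeline, the intent of the test is that the loop-carried tensor<i32> block argument and the scalar linalg.generic bodies collapse into plain i32/i1 values and cmpi/addi/cond_br operations; the exact output IR depends on the canonicalization patterns and is not pinned down by the single CHECK-LABEL above.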