diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -80,6 +80,13 @@
 void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
                           unsigned originalNumResults, Type newType);
 
+/// Get and set a FunctionLike operation's type signature.
+FunctionType getFunctionType(Operation *op);
+void setFunctionType(Operation *op, FunctionType newType);
+
+/// Get a FunctionLike operation's body.
+Region &getFunctionBody(Operation *op);
+
 } // namespace impl
 
 namespace OpTrait {
@@ -134,7 +141,9 @@
   /// Returns true if this function is external, i.e. it has no body.
   bool isExternal() { return empty(); }
 
-  Region &getBody() { return this->getOperation()->getRegion(0); }
+  Region &getBody() {
+    return ::mlir::impl::getFunctionBody(this->getOperation());
+  }
 
   /// Delete all blocks from this function.
   void eraseBody() {
@@ -198,7 +207,7 @@
   /// hide this one if the concrete class does not use FunctionType for the
   /// function type under the hood.
   FunctionType getType() {
-    return getTypeAttr().getValue().template cast<FunctionType>();
+    return ::mlir::impl::getFunctionType(this->getOperation());
   }
 
   /// Return the type of this function without the specified arguments and
@@ -542,15 +551,7 @@
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setType(FunctionType newType) {
-  SmallVector<char, 16> nameBuf;
-  auto oldType = getType();
-  auto *concreteOp = static_cast<ConcreteType *>(this);
-
-  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
-    concreteOp->removeAttr(getArgAttrName(i, nameBuf));
-  for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++)
-    concreteOp->removeAttr(getResultAttrName(i, nameBuf));
-  (*concreteOp)->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  ::mlir::impl::setFunctionType(this->getOperation(), newType);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -99,3 +99,35 @@
     op->removeAttr(nameAttr);
   }
 }
+
+//===----------------------------------------------------------------------===//
+// Function type signature.
+//===----------------------------------------------------------------------===//
+
+FunctionType mlir::impl::getFunctionType(Operation *op) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  return op->getAttrOfType<TypeAttr>(mlir::impl::getTypeAttrName())
+      .getValue()
+      .cast<FunctionType>();
+}
+
+void mlir::impl::setFunctionType(Operation *op, FunctionType newType) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  SmallVector<char, 16> nameBuf;
+  FunctionType oldType = getFunctionType(op);
+
+  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
+    op->removeAttr(getArgAttrName(i, nameBuf));
+  for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++)
+    op->removeAttr(getResultAttrName(i, nameBuf));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+}
+
+//===----------------------------------------------------------------------===//
+// Function body.
+//===----------------------------------------------------------------------===//
+
+Region &mlir::impl::getFunctionBody(Operation *op) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  return op->getRegion(0);
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionSupport.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/SetVector.h"
@@ -2517,30 +2518,43 @@
 /// Create a default conversion pattern that rewrites the type signature of a
-/// FuncOp.
+/// FunctionLike op (such as FuncOp) whose type attribute is a FunctionType.
 namespace {
-struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
-  FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
-      : OpConversionPattern(converter, ctx) {}
+struct FuncOpSignatureConversion : public ConversionPattern {
+  explicit FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern(/*benefit=*/1, converter, MatchAnyOpTypeTag()) {}
 
-  /// Hook for derived classes to implement combined matching and rewriting.
+  /// Hook to implement combined matching and rewriting for FunctionLike ops.
   LogicalResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    FunctionType type = funcOp.getType();
+    if (!op->hasTrait<OpTrait::FunctionLike>())
+      return failure();
+
+    // FunctionLike lets concrete ops store a non-FunctionType signature (they
+    // then hide getType()); mlir::impl::getFunctionType would assert on those,
+    // so only handle ops whose type attribute really is a FunctionType.
+    auto typeAttr = op->getAttrOfType<TypeAttr>(mlir::impl::getTypeAttrName());
+    if (!typeAttr || !typeAttr.getValue().isa<FunctionType>())
+      return failure();
+    FunctionType oldType = mlir::impl::getFunctionType(op);
 
     // Convert the original function types.
-    TypeConverter::SignatureConversion result(type.getNumInputs());
+    TypeConverter::SignatureConversion result(oldType.getNumInputs());
     SmallVector<Type, 1> newResults;
-    if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
-        failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
-                                           &result)))
+    if (failed(
+            typeConverter->convertSignatureArgs(oldType.getInputs(), result)) ||
+        failed(typeConverter->convertTypes(oldType.getResults(), newResults)) ||
+        failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
+                                           *typeConverter, &result)))
       return failure();
 
-    // Update the function signature in-place.
-    rewriter.updateRootInPlace(funcOp, [&] {
-      funcOp.setType(FunctionType::get(funcOp.getContext(),
-                                       result.getConvertedTypes(), newResults));
-    });
+    // Update the function signature in-place. setFunctionType also drops any
+    // argument/result attributes past the new signature's arity.
+    auto newType = FunctionType::get(rewriter.getContext(),
+                                     result.getConvertedTypes(), newResults);
+
+    rewriter.updateRootInPlace(
+        op, [&] { mlir::impl::setFunctionType(op, newType); });
+
     return success();
   }
 };