diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -80,6 +80,13 @@
 void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
                           unsigned originalNumResults, Type newType);
 
+/// Get and set a FunctionLike operation's type signature.
+FunctionType getFunctionType(Operation *op);
+void setFunctionType(Operation *op, FunctionType newType);
+
+/// Get a FunctionLike operation's body.
+Region &getFunctionBody(Operation *op);
+
 } // namespace impl
 
 namespace OpTrait {
@@ -134,7 +141,9 @@
   /// Returns true if this function is external, i.e. it has no body.
   bool isExternal() { return empty(); }
 
-  Region &getBody() { return this->getOperation()->getRegion(0); }
+  Region &getBody() {
+    return ::mlir::impl::getFunctionBody(this->getOperation());
+  }
 
   /// Delete all blocks from this function.
   void eraseBody() {
@@ -198,7 +207,7 @@
   /// hide this one if the concrete class does not use FunctionType for the
   /// function type under the hood.
   FunctionType getType() {
-    return getTypeAttr().getValue().template cast<FunctionType>();
+    return ::mlir::impl::getFunctionType(this->getOperation());
   }
 
   /// Return the type of this function without the specified arguments and
@@ -542,15 +551,7 @@
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setType(FunctionType newType) {
-  SmallVector<char, 16> nameBuf;
-  auto oldType = getType();
-  auto *concreteOp = static_cast<ConcreteType *>(this);
-
-  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
-    concreteOp->removeAttr(getArgAttrName(i, nameBuf));
-  for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++)
-    concreteOp->removeAttr(getResultAttrName(i, nameBuf));
-  (*concreteOp)->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  ::mlir::impl::setFunctionType(this->getOperation(), newType);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -421,6 +421,20 @@
   using ConversionPattern::matchAndRewrite;
 };
 
+/// Add a pattern to the given pattern list to convert the signature of a
+/// FunctionLike op with the given type converter.
+void populateFunctionLikeTypeConversionPattern(
+    StringRef functionLikeOpName, OwningRewritePatternList &patterns,
+    MLIRContext *ctx, TypeConverter &converter);
+
+template <typename FuncOpT>
+void populateFunctionLikeTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter) {
+  populateFunctionLikeTypeConversionPattern(FuncOpT::getOperationName(),
+                                            patterns, ctx, converter);
+}
+
 /// Add a pattern to the given pattern list to convert the signature of a FuncOp
 /// with the given type converter.
 void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -99,3 +99,35 @@
     op->removeAttr(nameAttr);
   }
 }
+
+//===----------------------------------------------------------------------===//
+// Function type signature.
+//===----------------------------------------------------------------------===//
+
+FunctionType mlir::impl::getFunctionType(Operation *op) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  return op->getAttrOfType<TypeAttr>(mlir::impl::getTypeAttrName())
+      .getValue()
+      .cast<FunctionType>();
+}
+
+void mlir::impl::setFunctionType(Operation *op, FunctionType newType) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  SmallVector<char, 16> nameBuf;
+  FunctionType oldType = getFunctionType(op);
+
+  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
+    op->removeAttr(getArgAttrName(i, nameBuf));
+  for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++)
+    op->removeAttr(getResultAttrName(i, nameBuf));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+}
+
+//===----------------------------------------------------------------------===//
+// Function body.
+//===----------------------------------------------------------------------===//
+
+Region &mlir::impl::getFunctionBody(Operation *op) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  return op->getRegion(0);
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionSupport.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/SetVector.h"
@@ -2515,41 +2516,51 @@
 }
 
 /// Create a default conversion pattern that rewrites the type signature of a
-/// FuncOp.
+/// FunctionLike op.
 namespace {
-struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
-  FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
-      : OpConversionPattern(converter, ctx) {}
+struct FunctionLikeSignatureConversion : public ConversionPattern {
+  FunctionLikeSignatureConversion(StringRef functionLikeOpName,
+                                  MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
 
-  /// Hook for derived classes to implement combined matching and rewriting.
+  /// Hook to implement combined matching and rewriting for FunctionLike ops.
   LogicalResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    FunctionType type = funcOp.getType();
+    FunctionType type = mlir::impl::getFunctionType(op);
 
     // Convert the original function types.
     TypeConverter::SignatureConversion result(type.getNumInputs());
     SmallVector<Type, 1> newResults;
     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
-                                           &result)))
+        failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
+                                           *typeConverter, &result)))
       return failure();
 
     // Update the function signature in-place.
-    rewriter.updateRootInPlace(funcOp, [&] {
-      funcOp.setType(FunctionType::get(funcOp.getContext(),
-                                       result.getConvertedTypes(), newResults));
-    });
+    auto newType = FunctionType::get(rewriter.getContext(),
+                                     result.getConvertedTypes(), newResults);
+
+    rewriter.updateRootInPlace(
+        op, [&] { mlir::impl::setFunctionType(op, newType); });
+
     return success();
   }
 };
 } // end anonymous namespace
 
+void mlir::populateFunctionLikeTypeConversionPattern(
+    StringRef functionLikeOpName, OwningRewritePatternList &patterns,
+    MLIRContext *ctx, TypeConverter &converter) {
+  patterns.insert<FunctionLikeSignatureConversion>(functionLikeOpName, ctx,
+                                                   converter);
+}
+
 void mlir::populateFuncOpTypeConversionPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx,
     TypeConverter &converter) {
-  patterns.insert<FuncOpSignatureConversion>(ctx, converter);
+  populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, ctx, converter);
 }
 
 //===----------------------------------------------------------------------===//