diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -22,13 +22,16 @@ namespace mlir { class AffineForOp; class AffineMap; -class FuncOp; class LoopLikeOpInterface; struct MemRefRegion; class OpBuilder; class Value; class ValueRange; +namespace func { +class FuncOp; +} // namespace func + namespace scf { class ForOp; class ParallelOp; @@ -79,7 +82,7 @@ /// Promotes all single iteration AffineForOp's in the Function, i.e., moves /// their body into the containing Block. -void promoteSingleIterationLoops(FuncOp f); +void promoteSingleIterationLoops(func::FuncOp f); /// Skew the operations in an affine.for's body with the specified /// operation-wise shifts. The shifts are with respect to the original execution @@ -92,7 +95,7 @@ /// Identify valid and profitable bands of loops to tile. This is currently just /// a temporary placeholder to test the mechanics of tiled code generation. /// Returns all maximal outermost perfect loop nests to tile. -void getTileableBands(FuncOp f, +void getTileableBands(func::FuncOp f, std::vector> *bands); /// Tiles the specified band of perfectly nested loops creating tile-space loops @@ -260,7 +263,7 @@ ArrayRef numProcessors); /// Gathers all AffineForOps in 'builtin.func' grouped by loop depth. -void gatherLoops(FuncOp func, +void gatherLoops(func::FuncOp func, std::vector> &depthToLoops); /// Creates an AffineForOp while ensuring that the lower and upper bounds are diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -18,6 +18,9 @@ #include namespace mlir { +namespace func { +class FuncOp; +} // namespace func class AffineForOp; @@ -28,53 +31,56 @@ /// Creates a simplification pass for affine structures (maps and sets). In /// addition, this pass also normalizes memrefs to have the trivial (identity) /// layout map. -std::unique_ptr> createSimplifyAffineStructuresPass(); +std::unique_ptr> +createSimplifyAffineStructuresPass(); /// Creates a loop invariant code motion pass that hoists loop invariant /// operations out of affine loops. -std::unique_ptr> +std::unique_ptr> createAffineLoopInvariantCodeMotionPass(); /// Creates a pass to convert all parallel affine.for's into 1-d affine.parallel /// ops. -std::unique_ptr> createAffineParallelizePass(); +std::unique_ptr> createAffineParallelizePass(); /// Apply normalization transformations to affine loop-like ops. -std::unique_ptr> createAffineLoopNormalizePass(); +std::unique_ptr> createAffineLoopNormalizePass(); /// Performs packing (or explicit copying) of accessed memref regions into /// buffers in the specified faster memory space through either pointwise copies /// or DMA operations. -std::unique_ptr> createAffineDataCopyGenerationPass( +std::unique_ptr> createAffineDataCopyGenerationPass( unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = std::numeric_limits::max()); /// Overload relying on pass options for initialization. -std::unique_ptr> createAffineDataCopyGenerationPass(); +std::unique_ptr> +createAffineDataCopyGenerationPass(); /// Creates a pass to replace affine memref accesses by scalars using store to /// load forwarding and redundant load elimination; consequently also eliminate /// dead allocs. -std::unique_ptr> createAffineScalarReplacementPass(); +std::unique_ptr> +createAffineScalarReplacementPass(); /// Creates a pass that transforms perfectly nested loops with independent /// bounds into a single loop. -std::unique_ptr> createLoopCoalescingPass(); +std::unique_ptr> createLoopCoalescingPass(); /// Creates a loop fusion pass which fuses loops according to type of fusion /// specified in `fusionMode`. Buffers of size less than or equal to /// `localBufSizeThreshold` are promoted to memory space `fastMemorySpace`. -std::unique_ptr> +std::unique_ptr> createLoopFusionPass(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, bool maximalFusion = false, enum FusionMode fusionMode = FusionMode::Greedy); /// Creates a pass to perform tiling on loop nests. -std::unique_ptr> +std::unique_ptr> createLoopTilingPass(uint64_t cacheSizeBytes); /// Overload relying on pass options for initialization. -std::unique_ptr> createLoopTilingPass(); +std::unique_ptr> createLoopTilingPass(); /// Creates a loop unrolling pass with the provided parameters. /// 'getUnrollFactor' is a function callback for clients to supply a function @@ -82,7 +88,7 @@ /// factors supplied through other means. If -1 is passed as the unrollFactor /// and no callback is provided, anything passed from the command-line (if at /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor). -std::unique_ptr> createLoopUnrollPass( +std::unique_ptr> createLoopUnrollPass( int unrollFactor = -1, bool unrollUpToFactor = false, bool unrollFull = false, const std::function &getUnrollFactor = nullptr); @@ -90,19 +96,19 @@ /// Creates a loop unroll jam pass to unroll jam by the specified factor. A /// factor of -1 lets the pass use the default factor or the one on the command /// line if provided. -std::unique_ptr> +std::unique_ptr> createLoopUnrollAndJamPass(int unrollJamFactor = -1); /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -std::unique_ptr> createPipelineDataTransferPass(); +std::unique_ptr> createPipelineDataTransferPass(); /// Creates a pass to vectorize loops, operations and data types using a /// target-independent, n-D super-vector abstraction. -std::unique_ptr> +std::unique_ptr> createSuperVectorizePass(ArrayRef virtualVectorSize); /// Overload relying on pass options for initialization. -std::unique_ptr> createSuperVectorizePass(); +std::unique_ptr> createSuperVectorizePass(); //===----------------------------------------------------------------------===// // Registration diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -21,10 +21,13 @@ class AffineIfOp; class AffineParallelOp; class DominanceInfo; -class FuncOp; class Operation; class PostDominanceInfo; +namespace func { +class FuncOp; +} // namespace func + namespace memref { class AllocOp; } // namespace memref @@ -96,7 +99,7 @@ /// Replace affine store and load accesses by scalars by forwarding stores to /// loads and eliminate invariant affine loads; consequently, eliminate dead /// allocs. -void affineScalarReplace(FuncOp f, DominanceInfo &domInfo, +void affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo, PostDominanceInfo &postDomInfo); /// Vectorizes affine loops in 'loops' using the n-D vectorization factors in diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -4,6 +4,10 @@ #include "mlir/Pass/Pass.h" namespace mlir { +namespace func { +class FuncOp; +} // namespace func + namespace bufferization { struct AnalysisBufferizationOptions; @@ -28,7 +32,7 @@ /// Creates a pass that finalizes a partial bufferization by removing remaining /// bufferization.to_tensor and bufferization.to_memref operations. -std::unique_ptr> createFinalizingBufferizePass(); +std::unique_ptr> createFinalizingBufferizePass(); /// Create a pass that bufferizes all ops that implement BufferizableOpInterface /// with One-Shot Bufferize. diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -11,7 +11,7 @@ include "mlir/Pass/PassBase.td" -def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> { +def BufferDeallocation : Pass<"buffer-deallocation", "func::FuncOp"> { let summary = "Adds all required dealloc operations for all allocations in " "the input program"; let description = [{ @@ -88,7 +88,7 @@ let constructor = "mlir::bufferization::createBufferDeallocationPass()"; } -def BufferHoisting : Pass<"buffer-hoisting", "FuncOp"> { +def BufferHoisting : Pass<"buffer-hoisting", "func::FuncOp"> { let summary = "Optimizes placement of allocation operations by moving them " "into common dominators and out of nested regions"; let description = [{ @@ -98,7 +98,7 @@ let constructor = "mlir::bufferization::createBufferHoistingPass()"; } -def BufferLoopHoisting : Pass<"buffer-loop-hoisting", "FuncOp"> { +def BufferLoopHoisting : Pass<"buffer-loop-hoisting", "func::FuncOp"> { let summary = "Optimizes placement of allocation operations by moving them " "out of loop nests"; let description = [{ @@ -133,7 +133,7 @@ let dependentDialects = ["memref::MemRefDialect"]; } -def FinalizingBufferize : Pass<"finalizing-bufferize", "FuncOp"> { +def FinalizingBufferize : Pass<"finalizing-bufferize", "func::FuncOp"> { let summary = "Finalize a partial bufferization"; let description = [{ A bufferize pass that finalizes a partial bufferization by removing @@ -231,7 +231,7 @@ let constructor = "mlir::bufferization::createOneShotBufferizePass()"; } -def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "FuncOp"> { +def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "func::FuncOp"> { let summary = "Promotes heap-based allocations to automatically managed " "stack-based allocations"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h @@ -11,10 +11,11 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -29,4 +30,25 @@ #include "mlir/Dialect/Func/IR/FuncOpsDialect.h.inc" +namespace mlir { +/// FIXME: This is a temporary using directive to ease the transition of FuncOp +/// to the Func dialect. This will be removed after all uses are updated. +using FuncOp = func::FuncOp; +} // namespace mlir + +namespace llvm { + +/// Allow stealing the low bits of FuncOp. +template <> +struct PointerLikeTypeTraits { + static inline void *getAsVoidPointer(mlir::func::FuncOp val) { + return const_cast(val.getAsOpaquePointer()); + } + static inline mlir::func::FuncOp getFromVoidPointer(void *p) { + return mlir::func::FuncOp::getFromOpaquePointer(p); + } + static constexpr int numLowBitsAvailable = 3; +}; +} // namespace llvm + #endif // MLIR_DIALECT_FUNC_IR_OPS_H diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -13,6 +13,7 @@ include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/IR/FunctionInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -201,6 +202,133 @@ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : Func_Op<"func", [ + AffineScope, AutomaticAllocationScope, CallableOpInterface, + FunctionOpInterface, IsolatedFromAbove, Symbol +]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). An external function declaration (used when + referring to a function declared in some other module) has no body. While + the MLIR textual form provides a nice inline syntax for function arguments, + they are internally represented as “block arguments” to the first block in + the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // External function definitions. + func.func @abort() + func.func @scribble(i32, i64, memref) -> f64 + + // A function that returns its argument twice: + func.func @count(%x: i64) -> (i64, i64) + attributes {fruit: "banana"} { + return %x, %x: i64, i64 + } + + // A function with an argument attribute + func.func @example_fn_arg(%x: i32 {swift.self = unit}) + + // A function with a result attribute + func.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + + // A function with an attribute + func.func @example_fn_attr() attributes {dialectName.attrName = false} + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$type, + OptionalAttr:$sym_visibility); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + static FuncOp create(Location location, StringRef name, FunctionType type, + ArrayRef attrs = {}); + static FuncOp create(Location location, StringRef name, FunctionType type, + Operation::dialect_attr_range attrs); + static FuncOp create(Location location, StringRef name, FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs); + + /// Create a deep copy of this function and all of its blocks, remapping any + /// operands that use values outside of the function using the map that is + /// provided (leaving them alone if no entry is present). If the mapper + /// contains entries for function arguments, these arguments are not + /// included in the new function. Replaces references to cloned sub-values + /// with the corresponding value that is copied, and adds those mappings to + /// the mapper. + FuncOp clone(BlockAndValueMapping &mapper); + FuncOp clone(); + + /// Clone the internal blocks and attributes from this function into dest. + /// Any cloned blocks are appended to the back of dest. This function + /// asserts that the attributes of the current function and dest are + /// compatible. + void cloneInto(FuncOp dest, BlockAndValueMapping &mapper); + + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getType().getResults(); } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getType().getResults(); } + + /// Verify the type attribute of this function. Returns failure and emits + /// an error if the attribute is invalid. + LogicalResult verifyType() { + auto type = getTypeAttr().getValue(); + if (!type.isa()) { + return emitOpError() + << "requires '" << FunctionOpInterface::getTypeAttrName() + << "' attribute of function type"; + } + return success(); + } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h --- a/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -23,6 +23,10 @@ } // namespace llvm namespace mlir { +namespace func { +class FuncOp; +} // namespace func + /// Pass that moves ops which are likely an index computation into gpu.launch /// body. std::unique_ptr createGpuLauchSinkIndexComputationsPass(); @@ -33,7 +37,7 @@ createGpuKernelOutliningPass(StringRef dataLayoutStr = StringRef()); /// Rewrites a function region so that GPU ops execute asynchronously. -std::unique_ptr> createGpuAsyncRegionPass(); +std::unique_ptr> createGpuAsyncRegionPass(); /// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect. void populateGpuAllReducePatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -14,7 +14,9 @@ #include "mlir/IR/OpDefinition.h" namespace mlir { +namespace func { class FuncOp; +} // namespace func namespace linalg { @@ -155,7 +157,8 @@ static StringRef getDependenceTypeStr(DependenceType depType); // Builds a linalg dependence graph for the ops of type LinalgOp under `f`. - static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f); + static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, + func::FuncOp f); LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); /// Returns the X such that op -> X is a dependence of type dt. diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -18,6 +18,13 @@ #include "mlir/Pass/Pass.h" namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +// TODO: Remove when all references have been updated. +using FuncOp = func::FuncOp; + namespace bufferization { struct AnalysisBufferizationOptions; } // namespace bufferization @@ -31,29 +38,32 @@ std::unique_ptr createLinalgNamedOpConversionPass(); -std::unique_ptr> +std::unique_ptr> createLinalgTilingPass(ArrayRef tileSizes = {}, linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops); -std::unique_ptr> +std::unique_ptr> createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca); -std::unique_ptr> createLinalgPromotionPass(); +std::unique_ptr> createLinalgPromotionPass(); -std::unique_ptr> createLinalgInlineScalarOperandsPass(); +std::unique_ptr> +createLinalgInlineScalarOperandsPass(); /// Create a pass to convert Linalg operations to scf.for loops and /// memref.load/memref.store accesses. -std::unique_ptr> createConvertLinalgToLoopsPass(); +std::unique_ptr> createConvertLinalgToLoopsPass(); /// Create a pass to convert Linalg operations to scf.parallel loops and /// memref.load/memref.store accesses. -std::unique_ptr> createConvertLinalgToParallelLoopsPass(); +std::unique_ptr> +createConvertLinalgToParallelLoopsPass(); /// Create a pass to convert Linalg operations to affine.for loops and /// affine_load/affine_store accesses. /// Placeholder for now, this is NYI. -std::unique_ptr> createConvertLinalgToAffineLoopsPass(); +std::unique_ptr> +createConvertLinalgToAffineLoopsPass(); /// This pass implements a cross-dialect bufferization approach and performs an /// analysis to determine which op operands and results may be bufferized in the @@ -68,11 +78,11 @@ /// Create a pass to convert Linalg operations which work on tensors to use /// buffers instead. -std::unique_ptr> createLinalgBufferizePass(); +std::unique_ptr> createLinalgBufferizePass(); /// Create a pass to convert named Linalg operations to Linalg generic /// operations. -std::unique_ptr> createLinalgGeneralizationPass(); +std::unique_ptr> createLinalgGeneralizationPass(); /// Create a pass to convert Linalg operations to equivalent operations that /// work on primitive types, if possible. @@ -82,27 +92,28 @@ /// Linalg strategy passes. //===----------------------------------------------------------------------===// /// Create a LinalgStrategyTileAndFusePass. -std::unique_ptr> createLinalgStrategyTileAndFusePass( +std::unique_ptr> +createLinalgStrategyTileAndFusePass( StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {}, const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyTilePass. -std::unique_ptr> createLinalgStrategyTilePass( +std::unique_ptr> createLinalgStrategyTilePass( StringRef opName = "", const linalg::LinalgTilingOptions &opt = linalg::LinalgTilingOptions(), const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyPadPass. -std::unique_ptr> createLinalgStrategyPadPass( +std::unique_ptr> createLinalgStrategyPadPass( StringRef opName = "", const linalg::LinalgPaddingOptions &opt = linalg::LinalgPaddingOptions(), const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyPromotePass. -std::unique_ptr> createLinalgStrategyPromotePass( +std::unique_ptr> createLinalgStrategyPromotePass( StringRef opName = "", const linalg::LinalgPromotionOptions &opt = linalg::LinalgPromotionOptions(), @@ -110,24 +121,25 @@ linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyGeneralizePass. -std::unique_ptr> createLinalgStrategyGeneralizePass( +std::unique_ptr> createLinalgStrategyGeneralizePass( StringRef opName = "", const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyDecomposePass. // TODO: if/when we need finer control add an `opName` parameter. -std::unique_ptr> createLinalgStrategyDecomposePass( +std::unique_ptr> createLinalgStrategyDecomposePass( const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyInterchangePass. -std::unique_ptr> createLinalgStrategyInterchangePass( +std::unique_ptr> +createLinalgStrategyInterchangePass( ArrayRef iteratorInterchange = {}, const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyVectorizePass. -std::unique_ptr> createLinalgStrategyVectorizePass( +std::unique_ptr> createLinalgStrategyVectorizePass( StringRef opName = "", linalg::LinalgVectorizationOptions opt = linalg::LinalgVectorizationOptions(), @@ -136,20 +148,22 @@ bool padVectorize = false); /// Create a LinalgStrategyEnablePass. -std::unique_ptr> createLinalgStrategyEnablePass( +std::unique_ptr> createLinalgStrategyEnablePass( linalg::LinalgEnablingOptions opt = linalg::LinalgEnablingOptions(), const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyLowerVectorsPass. -std::unique_ptr> createLinalgStrategyLowerVectorsPass( +std::unique_ptr> +createLinalgStrategyLowerVectorsPass( linalg::LinalgVectorLoweringOptions opt = linalg::LinalgVectorLoweringOptions(), const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); /// Create a LinalgStrategyRemoveMarkersPass. -std::unique_ptr> createLinalgStrategyRemoveMarkersPass(); +std::unique_ptr> +createLinalgStrategyRemoveMarkersPass(); //===----------------------------------------------------------------------===// // Registration diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h @@ -10,7 +10,9 @@ #define MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_ namespace mlir { +namespace func { class FuncOp; +} // namespace func namespace linalg { @@ -27,11 +29,11 @@ /// results in scf::ForOp yielding the value that originally transited through /// memory. // TODO: generalize on a per-need basis. -void hoistRedundantVectorTransfers(FuncOp func); +void hoistRedundantVectorTransfers(func::FuncOp func); /// Same behavior as `hoistRedundantVectorTransfers` but works on tensors /// instead of buffers. -void hoistRedundantVectorTransfersOnTensor(FuncOp func); +void hoistRedundantVectorTransfersOnTensor(func::FuncOp func); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Quant/Passes.h b/mlir/include/mlir/Dialect/Quant/Passes.h --- a/mlir/include/mlir/Dialect/Quant/Passes.h +++ b/mlir/include/mlir/Dialect/Quant/Passes.h @@ -19,18 +19,22 @@ #include "mlir/Pass/Pass.h" namespace mlir { +namespace func { +class FuncOp; +} // namespace func + namespace quant { /// Creates a pass that converts quantization simulation operations (i.e. /// FakeQuant and those like it) to casts into/out of supported QuantizedTypes. -std::unique_ptr> createConvertSimulatedQuantPass(); +std::unique_ptr> createConvertSimulatedQuantPass(); /// Creates a pass that converts constants followed by a qbarrier to a /// constant whose value is quantized. This is typically one of the last /// passes done when lowering to express actual quantized arithmetic in a /// low level representation. Because it modifies the constant, it is /// destructive and cannot be undone. -std::unique_ptr> createConvertConstPass(); +std::unique_ptr> createConvertConstPass(); //===----------------------------------------------------------------------===// // Registration diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -19,7 +19,6 @@ #include "llvm/ADT/STLExtras.h" namespace mlir { -class FuncOp; class Location; class Operation; class OpBuilder; @@ -28,6 +27,10 @@ class ValueRange; class Value; +namespace func { +class FuncOp; +} // namespace func + namespace scf { class IfOp; class ForOp; @@ -68,8 +71,9 @@ /// collide with another FuncOp name. // TODO: support more than single-block regions. // TODO: more flexible constant handling. -FailureOr outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, - Region ®ion, StringRef funcName); +FailureOr outlineSingleBlockRegion(RewriterBase &rewriter, + Location loc, Region ®ion, + StringRef funcName); /// Outline the then and/or else regions of `ifOp` as follows: /// - if `thenFn` is not null, `thenFnName` must be specified and the `then` @@ -79,8 +83,8 @@ /// Creates new FuncOps and thus cannot be used in a FuncOp pass. /// The client is responsible for providing a unique `thenFnName`/`elseFnName` /// that will not collide with another FuncOp name. -LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, FuncOp *thenFn, - StringRef thenFnName, FuncOp *elseFn, +LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, + StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName); /// Get a list of innermost parallel loops contained in `rootOp`. Innermost diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_SHAPE_IR_SHAPE_H #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -13,12 +13,10 @@ #ifndef MLIR_IR_BUILTINOPS_H_ #define MLIR_IR_BUILTINOPS_H_ -#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/SymbolTable.h" -#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -32,18 +30,6 @@ #include "mlir/IR/BuiltinOps.h.inc" namespace llvm { -/// Allow stealing the low bits of FuncOp. -template <> -struct PointerLikeTypeTraits { - static inline void *getAsVoidPointer(mlir::FuncOp val) { - return const_cast(val.getAsOpaquePointer()); - } - static inline mlir::FuncOp getFromVoidPointer(void *p) { - return mlir::FuncOp::getFromOpaquePointer(p); - } - static constexpr int numLowBitsAvailable = 3; -}; - /// Allow stealing the low bits of ModuleOp. template <> struct PointerLikeTypeTraits { diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -15,11 +15,9 @@ #define BUILTIN_OPS include "mlir/IR/BuiltinDialect.td" -include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/RegionKindInterface.td" include "mlir/IR/SymbolInterfaces.td" -include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -28,131 +26,6 @@ class Builtin_Op traits = []> : Op; -//===----------------------------------------------------------------------===// -// FuncOp -//===----------------------------------------------------------------------===// - -def FuncOp : Builtin_Op<"func", [ - AffineScope, AutomaticAllocationScope, CallableOpInterface, - FunctionOpInterface, IsolatedFromAbove, Symbol -]> { - let summary = "An operation with a name containing a single `SSACFG` region"; - let description = [{ - Operations within the function cannot implicitly capture values defined - outside of the function, i.e. Functions are `IsolatedFromAbove`. All - external references must use function arguments or attributes that establish - a symbolic connection (e.g. symbols referenced by name via a string - attribute like SymbolRefAttr). An external function declaration (used when - referring to a function declared in some other module) has no body. While - the MLIR textual form provides a nice inline syntax for function arguments, - they are internally represented as “block arguments” to the first block in - the region. - - Only dialect attribute names may be specified in the attribute dictionaries - for function arguments, results, or the function itself. - - Example: - - ```mlir - // External function definitions. - func @abort() - func @scribble(i32, i64, memref) -> f64 - - // A function that returns its argument twice: - func @count(%x: i64) -> (i64, i64) - attributes {fruit: "banana"} { - return %x, %x: i64, i64 - } - - // A function with an argument attribute - func @example_fn_arg(%x: i32 {swift.self = unit}) - - // A function with a result attribute - func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) - - // A function with an attribute - func @example_fn_attr() attributes {dialectName.attrName = false} - ``` - }]; - - let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$type, - OptionalAttr:$sym_visibility); - let regions = (region AnyRegion:$body); - - let builders = [OpBuilder<(ins - "StringRef":$name, "FunctionType":$type, - CArg<"ArrayRef", "{}">:$attrs, - CArg<"ArrayRef", "{}">:$argAttrs) - >]; - let extraClassDeclaration = [{ - static FuncOp create(Location location, StringRef name, FunctionType type, - ArrayRef attrs = {}); - static FuncOp create(Location location, StringRef name, FunctionType type, - Operation::dialect_attr_range attrs); - static FuncOp create(Location location, StringRef name, FunctionType type, - ArrayRef attrs, - ArrayRef argAttrs); - - /// Create a deep copy of this function and all of its blocks, remapping any - /// operands that use values outside of the function using the map that is - /// provided (leaving them alone if no entry is present). If the mapper - /// contains entries for function arguments, these arguments are not - /// included in the new function. Replaces references to cloned sub-values - /// with the corresponding value that is copied, and adds those mappings to - /// the mapper. - FuncOp clone(BlockAndValueMapping &mapper); - FuncOp clone(); - - /// Clone the internal blocks and attributes from this function into dest. - /// Any cloned blocks are appended to the back of dest. This function - /// asserts that the attributes of the current function and dest are - /// compatible. - void cloneInto(FuncOp dest, BlockAndValueMapping &mapper); - - //===------------------------------------------------------------------===// - // CallableOpInterface - //===------------------------------------------------------------------===// - - /// Returns the region on the current operation that is callable. This may - /// return null in the case of an external callable object, e.g. an external - /// function. - ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } - - /// Returns the results types that the callable region produces when - /// executed. - ArrayRef getCallableResults() { return getType().getResults(); } - - //===------------------------------------------------------------------===// - // FunctionOpInterface Methods - //===------------------------------------------------------------------===// - - /// Returns the argument types of this function. - ArrayRef getArgumentTypes() { return getType().getInputs(); } - - /// Returns the result types of this function. - ArrayRef getResultTypes() { return getType().getResults(); } - - /// Verify the type attribute of this function. Returns failure and emits - /// an error if the attribute is invalid. - LogicalResult verifyType() { - auto type = getTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + FunctionOpInterface::getTypeAttrName() + - "' attribute of function type"); - return success(); - } - - //===------------------------------------------------------------------===// - // SymbolOpInterface Methods - //===------------------------------------------------------------------===// - - bool isDeclaration() { return isExternal(); } - }]; - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; -} - //===----------------------------------------------------------------------===// // ModuleOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -142,7 +142,7 @@ /// Block argument iterator types. using BlockArgListType = Region::BlockArgListType; using args_iterator = BlockArgListType::iterator; - + //===------------------------------------------------------------------===// // Body Handling //===------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -22,6 +22,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -22,6 +22,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h" #include "../PassDetail.h" #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -10,6 +10,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" diff --git a/mlir/lib/Dialect/Affine/Transforms/PassDetail.h b/mlir/lib/Dialect/Affine/Transforms/PassDetail.h --- a/mlir/lib/Dialect/Affine/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Affine/Transforms/PassDetail.h @@ -10,6 +10,7 @@ #define DIALECT_AFFINE_TRANSFORMS_PASSDETAIL_H_ #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BlockAndValueMapping.h" diff --git a/mlir/lib/Dialect/Bufferization/Transforms/PassDetail.h b/mlir/lib/Dialect/Bufferization/Transforms/PassDetail.h --- a/mlir/lib/Dialect/Bufferization/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Bufferization/Transforms/PassDetail.h @@ -9,6 +9,7 @@ #ifndef DIALECT_BUFFERIZATION_TRANSFORMS_PASSDETAIL_H_ #define DIALECT_BUFFERIZATION_TRANSFORMS_PASSDETAIL_H_ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -21,6 +22,7 @@ #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" @@ -56,6 +58,12 @@ return true; } + /// All functions can be inlined. + bool isLegalToInline(Region *, Region *, bool, + BlockAndValueMapping &) const final { + return true; + } + //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// @@ -208,6 +216,149 @@ return value.isa() && type.isa(); } +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, + ArrayRef attrs) { + OpBuilder builder(location->getContext()); + OperationState state(location, getOperationName()); + FuncOp::build(builder, state, name, type, attrs); + return cast(Operation::create(state)); +} +FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, + Operation::dialect_attr_range attrs) { + SmallVector attrRef(attrs); + return create(location, name, type, llvm::makeArrayRef(attrRef)); +} +FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, + ArrayRef attrs, + ArrayRef argAttrs) { + FuncOp func = create(location, name, type, attrs); + func.setAllArgAttrs(argAttrs); + return func; +} + +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(FunctionOpInterface::getTypeAttrName(), + TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, + /*resultAttrs=*/llvm::None); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, buildFuncType); +} + +void FuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); +} + +LogicalResult FuncOp::verify() { + // If this function is external there is nothing to do. + if (isExternal()) + return success(); + + // Verify that the argument list of the function and the arg list of the entry + // block line up. The trait already verified that the number of arguments is + // the same between the signature and the block. + auto fnInputTypes = getType().getInputs(); + Block &entryBlock = front(); + for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) + if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) + return emitOpError("type of entry block argument #") + << i << '(' << entryBlock.getArgument(i).getType() + << ") must match the type of the corresponding argument in " + << "function signature(" << fnInputTypes[i] << ')'; + + return success(); +} + +/// Clone the internal blocks from this function into dest and all attributes +/// from this function to dest. +void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { + // Add the attributes of this function to dest. + llvm::MapVector newAttrMap; + for (const auto &attr : dest->getAttrs()) + newAttrMap.insert({attr.getName(), attr.getValue()}); + for (const auto &attr : (*this)->getAttrs()) + newAttrMap.insert({attr.getName(), attr.getValue()}); + + auto newAttrs = llvm::to_vector(llvm::map_range( + newAttrMap, [](std::pair attrPair) { + return NamedAttribute(attrPair.first, attrPair.second); + })); + dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); + + // Clone the body. + getBody().cloneInto(&dest.getBody(), mapper); +} + +/// Create a deep copy of this function and all of its blocks, remapping +/// any operands that use values outside of the function using the map that is +/// provided (leaving them alone if no entry is present). Replaces references +/// to cloned sub-values with the corresponding value that is copied, and adds +/// those mappings to the mapper. +FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { + // Create the new function. + FuncOp newFunc = cast(getOperation()->cloneWithoutRegions()); + + // If the function has a body, then the user might be deleting arguments to + // the function by specifying them in the mapper. If so, we don't add the + // argument to the input type vector. + if (!isExternal()) { + FunctionType oldType = getType(); + + unsigned oldNumArgs = oldType.getNumInputs(); + SmallVector newInputs; + newInputs.reserve(oldNumArgs); + for (unsigned i = 0; i != oldNumArgs; ++i) + if (!mapper.contains(getArgument(i))) + newInputs.push_back(oldType.getInput(i)); + + /// If any of the arguments were dropped, update the type and drop any + /// necessary argument attributes. + if (newInputs.size() != oldNumArgs) { + newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, + oldType.getResults())); + + if (ArrayAttr argAttrs = getAllArgAttrs()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(newInputs.size()); + for (unsigned i = 0; i != oldNumArgs; ++i) + if (!mapper.contains(getArgument(i))) + newArgAttrs.push_back(argAttrs[i]); + newFunc.setAllArgAttrs(newArgAttrs); + } + } + } + + /// Clone the current function into the new one and return it. + cloneInto(newFunc, mapper); + return newFunc; +} +FuncOp FuncOp::clone() { + BlockAndValueMapping mapper; + return clone(mapper); +} + //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/GPU/Utils.h" diff --git a/mlir/lib/Dialect/GPU/Transforms/PassDetail.h b/mlir/lib/Dialect/GPU/Transforms/PassDetail.h --- a/mlir/lib/Dialect/GPU/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/GPU/Transforms/PassDetail.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h --- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h @@ -10,6 +10,7 @@ #define DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_ #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Dialect.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" diff --git a/mlir/lib/Dialect/Quant/Transforms/PassDetail.h b/mlir/lib/Dialect/Quant/Transforms/PassDetail.h --- a/mlir/lib/Dialect/Quant/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Quant/Transforms/PassDetail.h @@ -9,6 +9,7 @@ #ifndef DIALECT_QUANT_TRANSFORMS_PASSDETAIL_H_ #define DIALECT_QUANT_TRANSFORMS_PASSDETAIL_H_ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h --- a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h @@ -6,9 +6,10 @@ // //===----------------------------------------------------------------------===// -#ifndef DIALECT_LOOPOPS_TRANSFORMS_PASSDETAIL_H_ -#define DIALECT_LOOPOPS_TRANSFORMS_PASSDETAIL_H_ +#ifndef DIALECT_SCF_TRANSFORMS_PASSDETAIL_H_ +#define DIALECT_SCF_TRANSFORMS_PASSDETAIL_H_ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -39,4 +40,4 @@ } // namespace mlir -#endif // DIALECT_LOOPOPS_TRANSFORMS_PASSDETAIL_H_ +#endif // DIALECT_SCF_TRANSFORMS_PASSDETAIL_H_ diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -14,6 +14,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Dialect/Shape/Transforms/PassDetail.h b/mlir/lib/Dialect/Shape/Transforms/PassDetail.h --- a/mlir/lib/Dialect/Shape/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Shape/Transforms/PassDetail.h @@ -9,6 +9,7 @@ #ifndef DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_ #define DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -11,6 +11,7 @@ #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" diff --git a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h --- a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h @@ -9,6 +9,7 @@ #ifndef DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H_ #define DIALECT_TENSOR_TRANSFORMS_PASSDETAIL_H_ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/lib/Dialect/Vector/Transforms/PassDetail.h b/mlir/lib/Dialect/Vector/Transforms/PassDetail.h --- a/mlir/lib/Dialect/Vector/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/Vector/Transforms/PassDetail.h @@ -9,6 +9,7 @@ #ifndef DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_ #define DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -16,10 +16,8 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" -#include "llvm/ADT/MapVector.h" using namespace mlir; @@ -72,149 +70,6 @@ addInterfaces(); } -//===----------------------------------------------------------------------===// -// FuncOp -//===----------------------------------------------------------------------===// - -FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, - ArrayRef attrs) { - OpBuilder builder(location->getContext()); - OperationState state(location, getOperationName()); - FuncOp::build(builder, state, name, type, attrs); - return cast(Operation::create(state)); -} -FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, - Operation::dialect_attr_range attrs) { - SmallVector attrRef(attrs); - return create(location, name, type, llvm::makeArrayRef(attrRef)); -} -FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, - ArrayRef attrs, - ArrayRef argAttrs) { - FuncOp func = create(location, name, type, attrs); - func.setAllArgAttrs(argAttrs); - return func; -} - -void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, - FunctionType type, ArrayRef attrs, - ArrayRef argAttrs) { - state.addAttribute(SymbolTable::getSymbolAttrName(), - builder.getStringAttr(name)); - state.addAttribute(function_interface_impl::getTypeAttrName(), - TypeAttr::get(type)); - state.attributes.append(attrs.begin(), attrs.end()); - state.addRegion(); - - if (argAttrs.empty()) - return; - assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/llvm::None); -} - -ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { - auto buildFuncType = - [](Builder &builder, ArrayRef argTypes, ArrayRef results, - function_interface_impl::VariadicFlag, - std::string &) { return builder.getFunctionType(argTypes, results); }; - - return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); -} - -void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); -} - -LogicalResult FuncOp::verify() { - // If this function is external there is nothing to do. - if (isExternal()) - return success(); - - // Verify that the argument list of the function and the arg list of the entry - // block line up. The trait already verified that the number of arguments is - // the same between the signature and the block. - auto fnInputTypes = getType().getInputs(); - Block &entryBlock = front(); - for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) - if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) - return emitOpError("type of entry block argument #") - << i << '(' << entryBlock.getArgument(i).getType() - << ") must match the type of the corresponding argument in " - << "function signature(" << fnInputTypes[i] << ')'; - - return success(); -} - -/// Clone the internal blocks from this function into dest and all attributes -/// from this function to dest. -void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { - // Add the attributes of this function to dest. - llvm::MapVector newAttrMap; - for (const auto &attr : dest->getAttrs()) - newAttrMap.insert({attr.getName(), attr.getValue()}); - for (const auto &attr : (*this)->getAttrs()) - newAttrMap.insert({attr.getName(), attr.getValue()}); - - auto newAttrs = llvm::to_vector(llvm::map_range( - newAttrMap, [](std::pair attrPair) { - return NamedAttribute(attrPair.first, attrPair.second); - })); - dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); - - // Clone the body. - getBody().cloneInto(&dest.getBody(), mapper); -} - -/// Create a deep copy of this function and all of its blocks, remapping -/// any operands that use values outside of the function using the map that is -/// provided (leaving them alone if no entry is present). Replaces references -/// to cloned sub-values with the corresponding value that is copied, and adds -/// those mappings to the mapper. -FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { - // Create the new function. - FuncOp newFunc = cast(getOperation()->cloneWithoutRegions()); - - // If the function has a body, then the user might be deleting arguments to - // the function by specifying them in the mapper. If so, we don't add the - // argument to the input type vector. - if (!isExternal()) { - FunctionType oldType = getType(); - - unsigned oldNumArgs = oldType.getNumInputs(); - SmallVector newInputs; - newInputs.reserve(oldNumArgs); - for (unsigned i = 0; i != oldNumArgs; ++i) - if (!mapper.contains(getArgument(i))) - newInputs.push_back(oldType.getInput(i)); - - /// If any of the arguments were dropped, update the type and drop any - /// necessary argument attributes. - if (newInputs.size() != oldNumArgs) { - newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, - oldType.getResults())); - - if (ArrayAttr argAttrs = getAllArgAttrs()) { - SmallVector newArgAttrs; - newArgAttrs.reserve(newInputs.size()); - for (unsigned i = 0; i != oldNumArgs; ++i) - if (!mapper.contains(getArgument(i))) - newArgAttrs.push_back(argAttrs[i]); - newFunc.setAllArgAttrs(newArgAttrs); - } - } - } - - /// Clone the current function into the new one and return it. - cloneInto(newFunc, mapper); - return newFunc; -} -FuncOp FuncOp::clone() { - BlockAndValueMapping mapper; - return clone(mapper); -} - //===----------------------------------------------------------------------===// // ModuleOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Target/LLVMIR/Dialect/All.h" #include "mlir/Target/LLVMIR/Export.h" @@ -34,7 +35,7 @@ return success(); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); registerAllToLLVMIRTranslations(registry); }); } diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -14,8 +14,8 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -67,10 +67,6 @@ bool InlinerInterface::isLegalToInline( Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const { - // Regions can always be inlined into functions. - if (isa(dest->getParentOp())) - return true; - if (auto *handler = getInterfaceFor(dest->getParentOp())) return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping); return false; diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -25,208 +25,3 @@ @property def body(self): return self.regions[0].blocks[0] - - -class FuncOp: - """Specialization for the func op class.""" - - def __init__(self, - name, - type, - *, - visibility=None, - body_builder=None, - loc=None, - ip=None): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. - if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = StringAttr.get( - str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError('External function does not have a body') - return self.regions[0].blocks[0] - - def add_entry_block(self): - """ - Add an entry block to the function body using the function signature to - infer block arguments. - Returns the newly created block - """ - if not self.is_external: - raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute - - @classmethod - def from_py_func(FuncOp, - *inputs: Type, - results: Optional[Sequence[Type]] = None, - name: Optional[str] = None): - """Decorator to define an MLIR FuncOp specified as a python function. - - Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are - active for the current thread (i.e. established in a `with` block). - - When applied as a decorator to a Python function, an entry block will - be constructed for the FuncOp with types as specified in `*inputs`. The - block arguments will be passed positionally to the Python function. In - addition, if the Python function accepts keyword arguments generally or - has a corresponding keyword argument, the following will be passed: - * `func_op`: The `func` op being defined. - - By default, the function name will be the Python function `__name__`. This - can be overriden by passing the `name` argument to the decorator. - - If `results` is not specified, then the decorator will implicitly - insert a `ReturnOp` with the `Value`'s returned from the decorated - function. It will also set the `FuncOp` type with the actual return - value types. If `results` is specified, then the decorated function - must return `None` and no implicit `ReturnOp` is added (nor are the result - types updated). The implicit behavior is intended for simple, single-block - cases, and users should specify result types explicitly for any complicated - cases. - - The decorated function can further be called from Python and will insert - a `CallOp` at the then-current insertion point, returning either None ( - if no return values), a unary Value (for one result), or a list of Values). - This mechanism cannot be used to emit recursive calls (by construction). - """ - - def decorator(f): - from . import func - # Introspect the callable for optional features. - sig = inspect.signature(f) - has_arg_func_op = False - for param in sig.parameters.values(): - if param.kind == param.VAR_KEYWORD: - has_arg_func_op = True - if param.name == "func_op" and (param.kind - == param.POSITIONAL_OR_KEYWORD or - param.kind == param.KEYWORD_ONLY): - has_arg_func_op = True - - # Emit the FuncOp. - implicit_return = results is None - symbol_name = name or f.__name__ - function_type = FunctionType.get( - inputs=inputs, results=[] if implicit_return else results) - func_op = FuncOp(name=symbol_name, type=function_type) - with InsertionPoint(func_op.add_entry_block()): - func_args = func_op.entry_block.arguments - func_kwargs = {} - if has_arg_func_op: - func_kwargs["func_op"] = func_op - return_values = f(*func_args, **func_kwargs) - if not implicit_return: - return_types = list(results) - assert return_values is None, ( - "Capturing a python function with explicit `results=` " - "requires that the wrapped function returns None.") - else: - # Coerce return values, add ReturnOp and rewrite func type. - if return_values is None: - return_values = [] - elif isinstance(return_values, tuple): - return_values = list(return_values) - elif isinstance(return_values, Value): - # Returning a single value is fine, coerce it into a list. - return_values = [return_values] - elif isinstance(return_values, OpView): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.operation.results - elif isinstance(return_values, Operation): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.results - else: - return_values = list(return_values) - func.ReturnOp(return_values) - # Recompute the function type. - return_types = [v.type for v in return_values] - function_type = FunctionType.get(inputs=inputs, results=return_types) - func_op.attributes["type"] = TypeAttr.get(function_type) - - def emit_call_op(*call_args): - call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), - call_args) - if return_types is None: - return None - elif len(return_types) == 1: - return call_op.result - else: - return call_op.results - - wrapped = emit_call_op - wrapped.__name__ = f.__name__ - wrapped.func_op = func_op - return wrapped - - return decorator diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -45,7 +45,7 @@ For example - f = builtin.FuncOp("foo", ...) + f = func.FuncOp("foo", ...) func.CallOp(f, [args]) func.CallOp([result_types], "foo", [args]) @@ -91,3 +91,207 @@ arguments, loc=loc, ip=ip) + +class FuncOp: + """Specialization for the func op class.""" + + def __init__(self, + name, + type, + *, + visibility=None, + body_builder=None, + loc=None, + ip=None): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = StringAttr.get( + str(visibility)) if visibility is not None else None + super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError('External function does not have a body') + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. + Returns the newly created block + """ + if not self.is_external: + raise IndexError('The function already has an entry block!') + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + + @classmethod + def from_py_func(FuncOp, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overriden by passing the `name` argument to the decorator. + + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`'s returned from the decorated + function. It will also set the `FuncOp` type with the actual return + value types. If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning either None ( + if no return values), a unary Value (for one result), or a list of Values). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import func + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and (param.kind + == param.POSITIONAL_OR_KEYWORD or + param.kind == param.KEYWORD_ONLY): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results) + func_op = FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None.") + else: + # Coerce return values, add ReturnOp and rewrite func type. + if return_values is None: + return_values = [] + elif isinstance(return_values, tuple): + return_values = list(return_values) + elif isinstance(return_values, Value): + # Returning a single value is fine, coerce it into a list. + return_values = [return_values] + elif isinstance(return_values, OpView): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.operation.results + elif isinstance(return_values, Operation): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.results + else: + return_values = list(return_values) + func.ReturnOp(return_values) + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get(inputs=inputs, results=return_types) + func_op.attributes["type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), + call_args) + if return_types is None: + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -23,8 +24,6 @@ using namespace mlir; -static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options"); - namespace { struct TestAffineDataCopy diff --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp --- a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" using namespace mlir; diff --git a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp --- a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp +++ b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopFusionUtils.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #define DEBUG_TYPE "test-loop-fusion" diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp --- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp @@ -13,6 +13,7 @@ #include #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/Transforms.h" diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/SCF/Utils/Utils.h" diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -18,6 +18,7 @@ #include "TestInterfaces.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp --- a/mlir/unittests/ExecutionEngine/Invoke.cpp +++ b/mlir/unittests/ExecutionEngine/Invoke.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -278,7 +278,7 @@ } static unsigned getNameLengthPlusArgTwice(unsigned arg) { - return FuncOp::getOperationName().size() + 2 * arg; + return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg; } unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { @@ -290,9 +290,11 @@ TEST(InterfaceAttachment, Operation) { MLIRContext context; + OpBuilder builder(&context); // Initially, the operation doesn't have the interface. - OwningOpRef moduleOp = ModuleOp::create(UnknownLoc::get(&context)); + OwningOpRef moduleOp = + builder.create(UnknownLoc::get(&context)); ASSERT_FALSE(isa(moduleOp->getOperation())); // We can attach an external interface and now the operaiton has it. @@ -305,16 +307,17 @@ EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u); // Default implementation can be overridden. - OwningOpRef funcOp = - FuncOp::create(UnknownLoc::get(&context), "function", - FunctionType::get(&context, {}, {})); - ASSERT_FALSE(isa(funcOp->getOperation())); - FuncOp::attachInterface(context); - iface = dyn_cast(funcOp->getOperation()); + OwningOpRef castOp = + builder.create(UnknownLoc::get(&context), + TypeRange(), ValueRange()); + ASSERT_FALSE(isa(castOp->getOperation())); + UnrealizedConversionCastOp::attachInterface( + context); + iface = dyn_cast(castOp->getOperation()); ASSERT_TRUE(iface != nullptr); - EXPECT_EQ(iface.getNameLengthPlusArg(10), 22u); + EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u); EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); - EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 28u); + EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u); EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); // Another context doesn't have the interfaces registered. diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Pass/AnalysisManager.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -51,6 +52,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) { MLIRContext context; + context.loadDialect(); Builder builder(&context); // Create a function and a module. @@ -81,6 +83,7 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) { MLIRContext context; + context.loadDialect(); Builder builder(&context); // Create a function and a module. diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt --- a/mlir/unittests/Pass/CMakeLists.txt +++ b/mlir/unittests/Pass/CMakeLists.txt @@ -5,4 +5,5 @@ ) target_link_libraries(MLIRPassTests PRIVATE + MLIRFunc MLIRPass) diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Pass/PassManager.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -47,6 +48,7 @@ TEST(PassManagerTest, OpSpecificAnalysis) { MLIRContext context; + context.loadDialect(); Builder builder(&context); // Create a module with 2 functions.