diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h @@ -16,6 +16,8 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/IR/FunctionInterfaces.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -14,6 +14,8 @@ #define MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/IR/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// @@ -627,6 +629,64 @@ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// pdl_interp::FuncOp +//===----------------------------------------------------------------------===// + +def PDLInterp_FuncOp : PDLInterp_Op<"func", [ + FunctionOpInterface, IsolatedFromAbove, Symbol + ]> { + let summary = "PDL Interpreter Function Operation"; + let description = [{ + `pdl_interp.func` operations act as interpreter functions. These are + callable SSA-region operations that contain other interpreter operations. + Interpreter functions are used for both the matching and the rewriting + portion of the interpreter. + + Example: + + ```mlir + pdl_interp.func @rewriter(%root: !pdl.operation) { + %op = pdl_interp.create_operation "foo.new_operation" + pdl_interp.erase %root + pdl_interp.finalize + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$type + ); + let regions = (region MinSizedRegion<1>:$body); + + // Create the function with the given name and type. This also automatically + // inserts the entry block for the function. + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + let extraClassDeclaration = [{ + /// Returns the type of this function. + /// FIXME: We should drive this via the ODS `type` param. + FunctionType getType() { + return getTypeAttr().getValue().cast(); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return type().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return type().getResults(); } + }]; + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + //===----------------------------------------------------------------------===// // pdl_interp::GetAttributeOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -133,16 +133,6 @@ /// Returns the result types of this function. ArrayRef getResultTypes() { return getType().getResults(); } - /// Verify the type attribute of this function. Returns failure and emits - /// an error if the attribute is invalid. - LogicalResult verifyType() { - auto type = getTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + FunctionOpInterface::getTypeAttrName() + - "' attribute of function type"); - return success(); - } - //===------------------------------------------------------------------===// // SymbolOpInterface Methods //===------------------------------------------------------------------===// @@ -150,7 +140,6 @@ bool isDeclaration() { return isExternal(); } }]; let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -86,10 +86,8 @@ bool allowVariadic, FuncTypeBuilder funcTypeBuilder); -/// Printer implementation for function-like operations. Accepts lists of -/// argument and result types to use while printing. -void printFunctionOp(OpAsmPrinter &p, Operation *op, ArrayRef argTypes, - bool isVariadic, ArrayRef resultTypes); +/// Printer implementation for function-like operations. +void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic); /// Prints the signature of the function-like operation `op`. Assumes `op` has /// is a FunctionOpInterface and has passed verification. diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -14,6 +14,7 @@ #ifndef MLIR_IR_FUNCTIONINTERFACES_H #define MLIR_IR_FUNCTIONINTERFACES_H +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/SymbolTable.h" diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -79,25 +79,41 @@ InterfaceMethod<[{ Verify the contents of the body of this function. - Note: The default implementation merely checks that of the entry block - exists, it has the same number arguments as the function type. + Note: The default implementation merely checks that if the entry block + exists, it has the same number and type of arguments as the function type. }], "::mlir::LogicalResult", "verifyBody", (ins), /*methodBody=*/[{}], /*defaultImplementation=*/[{ if ($_op.isExternal()) return success(); - - unsigned numArguments = $_op.getNumArguments(); - if ($_op.front().getNumArguments() != numArguments) + ArrayRef fnInputTypes = $_op.getArgumentTypes(); + Block &entryBlock = $_op.front(); + + unsigned numArguments = fnInputTypes.size(); + if (entryBlock.getNumArguments() != numArguments) return $_op.emitOpError("entry block must have ") << numArguments << " arguments to match function signature"; + + for (unsigned i = 0, e = fnInputTypes.size(); i != e; ++i) { + Type argType = entryBlock.getArgument(i).getType(); + if (fnInputTypes[i] != argType) { + return $_op.emitOpError("type of entry block argument #") + << i << '(' << argType + << ") must match the type of the corresponding argument in " + << "function signature(" << fnInputTypes[i] << ')'; + } + } + return success(); }]>, InterfaceMethod<[{ Verify the type attribute of the function for derived op-specific invariants. }], - "::mlir::LogicalResult", "verifyType">, + "::mlir::LogicalResult", "verifyType", (ins), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return success(); + }]>, ]; let extraClassDeclaration = [{ @@ -108,6 +124,31 @@ /// Return the name of the function. StringRef getName() { return SymbolTable::getSymbolName(*this); } }]; + let extraTraitClassDeclaration = [{ + //===------------------------------------------------------------------===// + // Builders + //===------------------------------------------------------------------===// + + /// Build the function with the given name, attributes, and type. This + /// builder also inserts an entry block into the function body with the + /// given argument types. + static void buildWithEntryBlock( + OpBuilder &builder, OperationState &state, StringRef name, Type type, + ArrayRef attrs, ArrayRef inputTypes) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(function_interface_impl::getTypeAttrName(), + TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + + // Add the function body. + Region *bodyRegion = state.addRegion(); + Block *body = new Block(); + bodyRegion->push_back(body); + for (Type input : inputTypes) + body->addArgument(input, state.location); + } + }]; let extraSharedClassDeclaration = [{ /// Block list iterator types. using BlockListType = Region::BlockListType; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1942,6 +1942,11 @@ CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">, "region with " # numBlocks # " blocks">; +// A region with at least the given number of blocks. +class MinSizedRegion : Region< + CPred<"::llvm::hasNItemsOrMore($_self, " # numBlocks # ")">, + "region with at least " # numBlocks # " blocks">; + // A variadic region constraint. It expands to zero or more of the base region. class VariadicRegion : Region; diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -32,7 +32,7 @@ /// given module containing PDL pattern operations. struct PatternLowering { public: - PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); + PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule); /// Generate code for matching and rewriting based on the pattern operations /// within the module. @@ -110,7 +110,7 @@ OpBuilder builder; /// The matcher function used for all match related logic within PDL patterns. - FuncOp matcherFunc; + pdl_interp::FuncOp matcherFunc; /// The rewriter module containing the all rewrite related logic within PDL /// patterns. @@ -137,7 +137,8 @@ }; } // namespace -PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) +PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc, + ModuleOp rewriterModule) : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} @@ -150,7 +151,7 @@ // Insert the root operation, i.e. argument to the matcher, at the root // position. - Block *matcherEntryBlock = matcherFunc.addEntryBlock(); + Block *matcherEntryBlock = &matcherFunc.front(); values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); // Generate a root matcher node from the provided PDL module. @@ -590,13 +591,14 @@ SymbolRefAttr PatternLowering::generateRewriter( pdl::PatternOp pattern, SmallVectorImpl &usedMatchValues) { - FuncOp rewriterFunc = - FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", - builder.getFunctionType(llvm::None, llvm::None)); + builder.setInsertionPointToEnd(rewriterModule.getBody()); + auto rewriterFunc = builder.create( + pattern.getLoc(), "pdl_generated_rewriter", + builder.getFunctionType(llvm::None, llvm::None)); rewriterSymbolTable.insert(rewriterFunc); // Generate the rewriter function body. - builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); + builder.setInsertionPointToEnd(&rewriterFunc.front()); // Map an input operand of the pattern to a generated interpreter value. DenseMap rewriteValues; @@ -902,7 +904,7 @@ // Create the main matcher function This function contains all of the match // related functionality from patterns in the module. OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); - FuncOp matcherFunc = builder.create( + auto matcherFunc = builder.create( module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), builder.getFunctionType(builder.getType(), /*results=*/llvm::None), diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1997,8 +1997,8 @@ // Returns a null type if any of the types provided are non-LLVM types, or if // there is more than one output type. static Type -buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, - ArrayRef inputs, ArrayRef outputs, +buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef inputs, + ArrayRef outputs, function_interface_impl::VariadicFlag variadicFlag) { Builder &b = parser.getBuilder(); if (outputs.size() > 1) { @@ -2159,7 +2159,7 @@ } /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: -/// - entry block arguments are of LLVM types and match the function signature. +/// - entry block arguments are of LLVM types. LogicalResult LLVMFuncOp::verifyRegions() { if (isExternal()) return success(); @@ -2171,9 +2171,6 @@ if (!isCompatibleType(argType)) return emitOpError("entry block argument #") << i << " is not of LLVM type"; - if (getType().getParamType(i) != argType) - return emitOpError("the type of entry block argument #") - << i << " does not match the function signature"; } return success(); diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/FunctionImplementation.h" using namespace mlir; using namespace mlir::pdl_interp; @@ -161,6 +162,29 @@ return success(); } +//===----------------------------------------------------------------------===// +// pdl_interp::FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs) { + buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, buildFuncType); +} + +void FuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); +} + //===----------------------------------------------------------------------===// // pdl_interp::GetValueTypeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -124,29 +124,7 @@ } void FuncOp::print(OpAsmPrinter &p) { - FunctionType fnType = getType(); - function_interface_impl::printFunctionOp( - p, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); -} - -LogicalResult FuncOp::verify() { - // If this function is external there is nothing to do. - if (isExternal()) - return success(); - - // Verify that the argument list of the function and the arg list of the entry - // block line up. The trait already verified that the number of arguments is - // the same between the signature and the block. - auto fnInputTypes = getType().getInputs(); - Block &entryBlock = front(); - for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) - if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) - return emitOpError("type of entry block argument #") - << i << '(' << entryBlock.getArgument(i).getType() - << ") must match the type of the corresponding argument in " - << "function signature(" << fnInputTypes[i] << ')'; - - return success(); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); } /// Clone the internal blocks from this function into dest and all attributes diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -344,9 +344,9 @@ p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); } -void mlir::function_interface_impl::printFunctionOp( - OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes) { +void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, + FunctionOpInterface op, + bool isVariadic) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) @@ -358,6 +358,8 @@ p << visibility.getValue() << ' '; p.printSymbolName(funcName); + ArrayRef argTypes = op.getArgumentTypes(); + ArrayRef resultTypes = op.getResultTypes(); printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(), {visibilityAttrName}); diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -239,7 +239,8 @@ private: /// Allocate memory indices for the results of operations within the matcher /// and rewriters. - void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); + void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, + ModuleOp rewriterModule); /// Generate the bytecode for the given operation. void generate(Region *region, ByteCodeWriter &writer); @@ -482,7 +483,7 @@ } // namespace void Generator::generate(ModuleOp module) { - FuncOp matcherFunc = module.lookupSymbol( + auto matcherFunc = module.lookupSymbol( pdl_interp::PDLInterpDialect::getMatcherFunctionName()); ModuleOp rewriterModule = module.lookupSymbol( pdl_interp::PDLInterpDialect::getRewriterModuleName()); @@ -494,7 +495,7 @@ // Generate code for the rewriter functions. ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); - for (FuncOp rewriterFunc : rewriterModule.getOps()) { + for (auto rewriterFunc : rewriterModule.getOps()) { rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); for (Operation &op : rewriterFunc.getOps()) generate(&op, rewriterByteCodeWriter); @@ -514,11 +515,11 @@ } } -void Generator::allocateMemoryIndices(FuncOp matcherFunc, +void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule) { // Rewriters use simplistic allocation scheme that simply assigns an index to // each result. - for (FuncOp rewriterFunc : rewriterModule.getOps()) { + for (auto rewriterFunc : rewriterModule.getOps()) { ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; auto processRewriterValue = [&](Value val) { valueToMemIndex.try_emplace(val, index++); diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -180,7 +180,7 @@ // ----- module { - // expected-error@+1 {{entry block argument #0 is not of LLVM type}} + // expected-error@+1 {{entry block argument #0('tensor<*xf32>') must match the type of the corresponding argument in function signature('i64')}} "llvm.func"() ({ ^bb0(%arg0: tensor<*xf32>): llvm.return @@ -189,16 +189,6 @@ // ----- -module { - // expected-error@+1 {{entry block argument #0 does not match the function signature}} - "llvm.func"() ({ - ^bb0(%arg0: i32): - llvm.return - }) {sym_name = "wrong_arg_number", type = !llvm.func} : () -> () -} - -// ----- - module { // expected-error@+1 {{failed to construct function type: expected LLVM type for function arguments}} llvm.func @foo(tensor<*xf32>) diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -10,7 +10,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.apply_constraint "multi_entity_constraint"(%root, %root : !pdl.operation, !pdl.operation) -> ^pat, ^end ^pat: @@ -24,7 +24,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.replaced_by_pattern" pdl_interp.erase %root pdl_interp.finalize @@ -41,7 +41,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %results = pdl_interp.get_results of %root : !pdl.range %types = pdl_interp.get_value_type of %results : !pdl.range pdl_interp.apply_constraint "multi_entity_var_constraint"(%results, %types : !pdl.range, !pdl.range) -> ^pat, ^end @@ -54,7 +54,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.replaced_by_pattern" pdl_interp.erase %root pdl_interp.finalize @@ -77,7 +77,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end ^pat: @@ -88,7 +88,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %operand = pdl_interp.get_operand 0 of %root pdl_interp.apply_rewrite "rewriter"[42](%root, %operand : !pdl.operation, !pdl.value) pdl_interp.finalize @@ -108,7 +108,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end ^pat: @@ -119,7 +119,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.apply_rewrite "creator"(%root : !pdl.operation) : !pdl.operation pdl_interp.erase %root pdl_interp.finalize @@ -136,7 +136,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end ^pat: @@ -147,7 +147,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %operands, %types = pdl_interp.apply_rewrite "var_creator"(%root : !pdl.operation) : !pdl.range, !pdl.range %op = pdl_interp.create_operation "test.success"(%operands : !pdl.range) -> (%types : !pdl.range) pdl_interp.replace %root with (%operands : !pdl.range) @@ -169,7 +169,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end ^pat: @@ -180,7 +180,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %type = pdl_interp.apply_rewrite "type_creator" : !pdl.type %newOp = pdl_interp.create_operation "test.success" -> (%type : !pdl.type) pdl_interp.erase %root @@ -202,7 +202,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %test_attr = pdl_interp.create_attribute unit %attr = pdl_interp.get_attribute "test_attr" of %root pdl_interp.are_equal %test_attr, %attr : !pdl.attribute -> ^pat, ^end @@ -215,7 +215,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -232,7 +232,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %const_types = pdl_interp.create_types [i32, i64] %results = pdl_interp.get_results of %root : !pdl.range %result_types = pdl_interp.get_value_type of %results : !pdl.range @@ -246,7 +246,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -270,7 +270,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end ^pat1: @@ -284,7 +284,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -305,7 +305,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %attr = pdl_interp.get_attribute "test_attr" of %root pdl_interp.check_attribute %attr is unit -> ^pat, ^end @@ -317,7 +317,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -338,7 +338,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operand_count of %root is at_least 1 -> ^exact_check, ^end ^exact_check: @@ -352,7 +352,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -375,7 +375,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end ^pat: @@ -386,7 +386,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -407,7 +407,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_result_count of %root is at_least 1 -> ^exact_check, ^end ^exact_check: @@ -421,7 +421,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -445,7 +445,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %attr = pdl_interp.get_attribute "test_attr" of %root pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end @@ -461,7 +461,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -482,7 +482,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %results = pdl_interp.get_results of %root : !pdl.range %result_types = pdl_interp.get_value_type of %results : !pdl.range pdl_interp.check_types %result_types are [i32] -> ^pat2, ^end @@ -495,7 +495,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -537,7 +537,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %attr = pdl_interp.get_attribute "test_attr" of %root pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end @@ -554,7 +554,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -587,7 +587,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %val = pdl_interp.get_result 0 of %root %ops = pdl_interp.get_users of %val : !pdl.value %op1 = pdl_interp.extract 1 of %ops : !pdl.operation @@ -599,7 +599,7 @@ } module @rewriters { - func @success(%matched : !pdl.operation) { + pdl_interp.func @success(%matched : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %matched pdl_interp.finalize @@ -620,7 +620,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %vals = pdl_interp.get_results of %root : !pdl.range %types = pdl_interp.get_value_type of %vals : !pdl.range %type1 = pdl_interp.extract 1 of %types : !pdl.type @@ -632,7 +632,7 @@ } module @rewriters { - func @success(%matched : !pdl.operation) { + pdl_interp.func @success(%matched : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %matched pdl_interp.finalize @@ -653,7 +653,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %vals = pdl_interp.get_results of %root : !pdl.range %val1 = pdl_interp.extract 1 of %vals : !pdl.value pdl_interp.is_not_null %val1 : !pdl.value -> ^success, ^end @@ -664,7 +664,7 @@ } module @rewriters { - func @success(%matched : !pdl.operation) { + pdl_interp.func @success(%matched : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %matched pdl_interp.finalize @@ -695,7 +695,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %val1 = pdl_interp.get_result 0 of %root %ops1 = pdl_interp.get_users of %val1 : !pdl.value pdl_interp.foreach %op1 : !pdl.operation in %ops1 { @@ -714,7 +714,7 @@ } module @rewriters { - func @success(%matched : !pdl.operation) { + pdl_interp.func @success(%matched : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %matched pdl_interp.finalize @@ -747,7 +747,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %val = pdl_interp.get_result 0 of %root %ops = pdl_interp.get_users of %val : !pdl.value pdl_interp.foreach %op : !pdl.operation in %ops { @@ -760,7 +760,7 @@ } module @rewriters { - func @success(%matched : !pdl.operation) { + pdl_interp.func @success(%matched : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %matched pdl_interp.finalize @@ -781,7 +781,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end ^next: %vals = pdl_interp.get_results of %root : !pdl.range @@ -796,7 +796,7 @@ } module @rewriters { - func @success(%matched : !pdl.operation) { + pdl_interp.func @success(%matched : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %matched pdl_interp.finalize @@ -817,7 +817,7 @@ // ----- module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end ^next: %vals = pdl_interp.get_results of %root : !pdl.range @@ -833,7 +833,7 @@ } module @rewriters { - func @success(%matched : !pdl.operation) { + pdl_interp.func @success(%matched : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %matched pdl_interp.finalize @@ -870,7 +870,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operand_count of %root is 5 -> ^pat1, ^end ^pat1: @@ -888,7 +888,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -921,7 +921,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end ^pat1: @@ -937,7 +937,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -956,7 +956,7 @@ // Test all of the various combinations related to `AttrSizedOperandSegments`. module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.attr_sized_operands" -> ^pat1, ^end ^pat1: @@ -996,7 +996,7 @@ } module @rewriters { - func @success(%root: !pdl.operation, %operands_0: !pdl.range, %operands_1: !pdl.range, %operands_2: !pdl.range, %operands_2_single: !pdl.value) { + pdl_interp.func @success(%root: !pdl.operation, %operands_0: !pdl.range, %operands_1: !pdl.range, %operands_2: !pdl.range, %operands_2_single: !pdl.value) { %op0 = pdl_interp.create_operation "test.success"(%operands_0 : !pdl.range) %op1 = pdl_interp.create_operation "test.success"(%operands_1 : !pdl.range) %op2 = pdl_interp.create_operation "test.success"(%operands_2 : !pdl.range) @@ -1025,7 +1025,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_result_count of %root is 5 -> ^pat1, ^end ^pat1: @@ -1043,7 +1043,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -1066,7 +1066,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_result_count of %root is 5 -> ^pat1, ^end ^pat1: @@ -1082,7 +1082,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -1100,7 +1100,7 @@ // Test all of the various combinations related to `AttrSizedResultSegments`. module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.attr_sized_results" -> ^pat1, ^end ^pat1: @@ -1140,7 +1140,7 @@ } module @rewriters { - func @success(%root: !pdl.operation, %results_0: !pdl.range, %results_1: !pdl.range, %results_2: !pdl.range, %results_2_single: !pdl.value) { + pdl_interp.func @success(%root: !pdl.operation, %results_0: !pdl.range, %results_1: !pdl.range, %results_2: !pdl.range, %results_2_single: !pdl.value) { %results_0_types = pdl_interp.get_value_type of %results_0 : !pdl.range %results_1_types = pdl_interp.get_value_type of %results_1 : !pdl.range %results_2_types = pdl_interp.get_value_type of %results_2 : !pdl.range @@ -1198,7 +1198,7 @@ // Check that the highest benefit pattern is selected. module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end ^pat1: @@ -1212,11 +1212,11 @@ } module @rewriters { - func @failure(%root : !pdl.operation) { + pdl_interp.func @failure(%root : !pdl.operation) { pdl_interp.erase %root pdl_interp.finalize } - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -1234,7 +1234,7 @@ // Check that ranges are properly forwarded to the result. module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end ^pat1: @@ -1248,7 +1248,7 @@ } module @rewriters { - func @success(%operands: !pdl.range, %types: !pdl.range, %root: !pdl.operation) { + pdl_interp.func @success(%operands: !pdl.range, %types: !pdl.range, %root: !pdl.operation) { %op = pdl_interp.create_operation "test.success"(%operands : !pdl.range) -> (%types : !pdl.range) %results = pdl_interp.get_results of %op : !pdl.range pdl_interp.replace %root with (%results : !pdl.range) @@ -1274,7 +1274,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end ^pat: @@ -1285,7 +1285,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %operand = pdl_interp.get_operand 0 of %root pdl_interp.replace %root with (%operand : !pdl.value) pdl_interp.finalize @@ -1310,7 +1310,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %attr = pdl_interp.get_attribute "test_attr" of %root pdl_interp.switch_attribute %attr to [0, unit](^end, ^pat) -> ^end @@ -1326,7 +1326,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -1347,7 +1347,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.switch_operand_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end ^pat: @@ -1361,7 +1361,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -1383,7 +1383,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.switch_operation_name of %root to ["foo.op", "test.op"](^end, ^pat1) -> ^end ^pat1: @@ -1397,7 +1397,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -1418,7 +1418,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { pdl_interp.switch_result_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end ^pat: @@ -1432,7 +1432,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -1453,7 +1453,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %attr = pdl_interp.get_attribute "test_attr" of %root pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end @@ -1472,7 +1472,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize @@ -1493,7 +1493,7 @@ //===----------------------------------------------------------------------===// module @patterns { - func @matcher(%root : !pdl.operation) { + pdl_interp.func @matcher(%root : !pdl.operation) { %results = pdl_interp.get_results of %root : !pdl.range %types = pdl_interp.get_value_type of %results : !pdl.range pdl_interp.switch_types %types to [[i64, i64], [i32]](^pat2, ^end) -> ^end @@ -1509,7 +1509,7 @@ } module @rewriters { - func @success(%root : !pdl.operation) { + pdl_interp.func @success(%root : !pdl.operation) { %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize