diff --git a/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h b/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h @@ -0,0 +1,28 @@ +//===- PDLToPDLInterp.h - PDL to PDL Interpreter conversion -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides a pass for PDL to PDL Interpreter dialect conversion. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H +#define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H + +#include + +namespace mlir { +class ModuleOp; +template +class OperationPass; + +/// Creates and returns a pass that lowers PDL patterns to PDL interpreter ops.
+std::unique_ptr> createPDLToPDLInterpPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -20,6 +20,7 @@ #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h" #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" +#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -198,6 +198,16 @@ let dependentDialects = ["LLVM::LLVMDialect"]; } +//===----------------------------------------------------------------------===// +// PDLToPDLInterp +//===----------------------------------------------------------------------===// + +def ConvertPDLToPDLInterp : Pass<"convert-pdl-to-pdl-interp", "ModuleOp"> { + let summary = "Convert PDL ops to PDL interpreter ops"; + let constructor = "mlir::createPDLToPDLInterpPass()"; + let dependentDialects = ["pdl_interp::PDLInterpDialect"]; +} + //===----------------------------------------------------------------------===// // SCFToStandard //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -36,6 +36,16 @@ let name = "pdl_interp"; let cppNamespace = "::mlir::pdl_interp"; let dependentDialects = ["pdl::PDLDialect"]; +
let extraClassDeclaration = [{ + /// Returns the name of the function containing the matcher code. This + /// function is called by the interpreter when matching an operation. + static StringRef getMatcherFunctionName() { return "matcher"; } + + /// Returns the name of the module containing the rewrite functions. These + /// functions are invoked by distinct patterns within the matcher function + /// to rewrite the IR after a successful match. + static StringRef getRewriterModuleName() { return "rewriters"; } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -157,8 +157,7 @@ } /// Utility override when the storage type represents the type id. template - void registerSingletonStorageType( - function_ref initFn = llvm::None) { + void registerSingletonStorageType(function_ref initFn = {}) { registerSingletonStorageType(TypeID::get(), initFn); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -0,0 +1,685 @@ +//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" +#include "../PassDetail.h" +#include "PredicateTree.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::pdl_to_pdl_interp; + +//===----------------------------------------------------------------------===// +// PatternLowering +//===----------------------------------------------------------------------===// + +namespace { +/// This class generates operations within the PDL Interpreter dialect from a +/// given module containing PDL pattern operations. +struct PatternLowering { +public: + PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); + + /// Generate code for matching and rewriting based on the pattern operations + /// within the module. + void lower(ModuleOp module); + +private: + using ValueMap = llvm::ScopedHashTable; + using ValueMapScope = llvm::ScopedHashTableScope; + + /// Generate interpreter operations for the tree rooted at the given matcher + /// node. + Block *generateMatcher(MatcherNode &node); + + /// Get or create an access to the provided positional value within the + /// current block. + Value getValueAt(Block *cur, Position *pos); + + /// Create an interpreter predicate operation, branching to the provided true + /// and false destinations. + void generatePredicate(Block *cur, Qualifier *question, Qualifier *answer, + Value val, Block *trueDest, Block *falseDest); + + /// Create an interpreter switch predicate operation, with a provided default + /// and several case destinations.
+ void generateSwitch(Block *cur, Qualifier *question, Value val, + Block *defaultDest, + ArrayRef> dests); + + /// Generate a rewriter function for the given pattern operation, and return + /// a reference to that function. + SymbolRefAttr generateRewriter(pdl::PatternOp pattern, + SmallVectorImpl &usedMatchValues); + + /// Generate the rewriter code for the given operation. + void generateRewriter(pdl::AttributeOp attrOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); + void generateRewriter(pdl::EraseOp eraseOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); + void generateRewriter(pdl::OperationOp operationOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); + void generateRewriter(pdl::CreateNativeOp createNativeOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); + void generateRewriter(pdl::ReplaceOp replaceOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); + void generateRewriter(pdl::TypeOp typeOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); + + /// Generate the values used for resolving the result types of an operation + /// created within a dag rewriter region. + void generateOperationResultTypeRewriter( + pdl::OperationOp op, SmallVectorImpl &types, + DenseMap &rewriteValues, + function_ref mapRewriteValue); + + /// A builder to use when generating interpreter operations. + OpBuilder builder; + + /// The matcher function used for all match related logic within PDL patterns. + FuncOp matcherFunc; + + /// The rewriter module containing all rewrite related logic within PDL + /// patterns. + ModuleOp rewriterModule; + + /// The symbol table of the rewriter module used for insertion. + SymbolTable rewriterSymbolTable; + + /// A scoped map connecting a position with the corresponding interpreter + /// value. + ValueMap values; + + /// A stack of blocks used as the failure destination for matcher nodes that + /// don't have an explicit failure path.
+ SmallVector failureBlockStack; + + /// A mapping between values defined in a pattern match, and the corresponding + /// positional value. + DenseMap valueToPosition; + + /// The set of operation values whose location will be used for newly + /// generated operations. + llvm::SetVector fusedLocOps; +}; +} // end anonymous namespace + +PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) + : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), + rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} + +void PatternLowering::lower(ModuleOp module) { + PredicateUniquer predicateUniquer; + PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); + + // Define top-level scope for the arguments to the matcher function. + ValueMapScope topLevelValueScope(values); + + // Insert the root operation, i.e. argument to the matcher, at the root + // position. + Block *matcherEntryBlock = matcherFunc.addEntryBlock(); + values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); + + // Generate a root matcher node from the provided PDL module. + std::unique_ptr root = MatcherNode::generateMatcherTree( + module, predicateBuilder, valueToPosition); + Block *firstMatcherBlock = generateMatcher(*root); + + // After generation, merge the first matcher block into the entry block. + matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), + firstMatcherBlock->getOperations()); + firstMatcherBlock->erase(); +} + +Block *PatternLowering::generateMatcher(MatcherNode &node) { + // Push a new scope for the values used by this matcher. + Block *block = matcherFunc.addBlock(); + ValueMapScope scope(values); + + // If this is the return node, simply insert the corresponding interpreter + // finalize.
+ if (isa(node)) { + builder.setInsertionPointToEnd(block); + builder.create(matcherFunc.getLoc()); + return block; + } + + // If this node contains a position, get the corresponding value for this + // block. + Position *position = node.getPosition(); + Value val = position ? getValueAt(block, position) : Value(); + + // Get the next block in the match sequence (from the failure node, if any). + std::unique_ptr &failureNode = node.getFailureNode(); + Block *nextBlock; + if (failureNode) { + nextBlock = generateMatcher(*failureNode); + failureBlockStack.push_back(nextBlock); + } else { + assert(!failureBlockStack.empty() && "expected valid failure block"); + nextBlock = failureBlockStack.back(); + } + + // If this value corresponds to an operation, record that we are going to use + // its location as part of a fused location. + bool isOperationValue = val && val.getType().isa(); + if (isOperationValue) + fusedLocOps.insert(val); + + // Generate code for a boolean predicate node. + if (auto *boolNode = dyn_cast(&node)) { + auto *child = generateMatcher(*boolNode->getSuccessNode()); + generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val, + child, nextBlock); + + // Generate code for a switch node. + } else if (auto *switchNode = dyn_cast(&node)) { + // Collect the next blocks for all of the children and generate a switch. + llvm::MapVector children; + for (auto &it : switchNode->getChildren()) + children.insert({it.first, generateMatcher(*it.second)}); + generateSwitch(block, node.getQuestion(), val, nextBlock, + children.takeVector()); + + // Generate code for a success node. + } else if (auto *successNode = dyn_cast(&node)) { + // Generate a rewriter for the pattern this success node represents, and + // track any values used from the match region.
+ pdl::PatternOp pattern = successNode->getPattern(); + SmallVector usedMatchValues; + SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); + + // Process any values used in the rewrite that are defined in the match. + std::vector mappedMatchValues; + mappedMatchValues.reserve(usedMatchValues.size()); + for (Position *position : usedMatchValues) + mappedMatchValues.push_back(getValueAt(block, position)); + + // Collect the set of operations generated by the rewriter. + SmallVector generatedOps; + for (auto op : pattern.getRewriter().body().getOps()) + generatedOps.push_back(*op.name()); + ArrayAttr generatedOpsAttr; + if (!generatedOps.empty()) + generatedOpsAttr = builder.getStrArrayAttr(generatedOps); + + // Grab the root kind if present. + StringAttr rootKindAttr; + if (Optional rootKind = pattern.getRootKind()) + rootKindAttr = builder.getStringAttr(*rootKind); + + builder.setInsertionPointToEnd(block); + builder.create( + pattern.getLoc(), mappedMatchValues, fusedLocOps.getArrayRef(), + rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), + nextBlock); + } + + if (failureNode) + failureBlockStack.pop_back(); + if (isOperationValue) + fusedLocOps.remove(val); + return block; +} + +Value PatternLowering::getValueAt(Block *cur, Position *pos) { + if (Value val = values.lookup(pos)) + return val; + + // Get the value for the parent position. + Value parentVal = getValueAt(cur, pos->getParent()); + + // TODO: Use a location from the position.
+ Location loc = parentVal.getLoc(); + builder.setInsertionPointToEnd(cur); + Value value; + switch (pos->getKind()) { + case Predicates::OperationPos: + value = builder.create( + loc, builder.getType(), parentVal); + break; + case Predicates::OperandPos: { + auto *operandPos = cast(pos); + value = builder.create( + loc, builder.getType(), parentVal, + operandPos->getOperandNumber()); + break; + } + case Predicates::AttributePos: { + auto *attrPos = cast(pos); + value = builder.create( + loc, builder.getType(), parentVal, + attrPos->getName().strref()); + break; + } + case Predicates::TypePos: { + if (parentVal.getType().isa()) + value = builder.create(loc, parentVal); + else + value = builder.create(loc, parentVal); + break; + } + case Predicates::ResultPos: { + auto *resPos = cast(pos); + value = builder.create( + loc, builder.getType(), parentVal, + resPos->getResultNumber()); + break; + } + default: + llvm_unreachable("Generating unknown Position getter"); + break; + } + values.insert(pos, value); + return value; +} + +void PatternLowering::generatePredicate(Block *currentBlock, + Qualifier *question, Qualifier *answer, + Value val, Block *trueDest, + Block *falseDest) { + builder.setInsertionPointToEnd(currentBlock); + Location loc = val.getLoc(); + switch (question->getKind()) { + case Predicates::IsNotNullQuestion: + builder.create(loc, val, trueDest, falseDest); + break; + case Predicates::OperationNameQuestion: { + auto *opNameAnswer = cast(answer); + builder.create( + loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest); + break; + } + case Predicates::TypeQuestion: { + auto *ans = cast(answer); + builder.create( + loc, val, TypeAttr::get(ans->getValue()), trueDest, falseDest); + break; + } + case Predicates::AttributeQuestion: { + auto *ans = cast(answer); + builder.create(loc, val, ans->getValue(), + trueDest, falseDest); + break; + } + case Predicates::OperandCountQuestion: { + auto *unsignedAnswer = cast(answer); + builder.create(
+ loc, val, unsignedAnswer->getValue(), trueDest, falseDest); + break; + } + case Predicates::ResultCountQuestion: { + auto *unsignedAnswer = cast(answer); + builder.create( + loc, val, unsignedAnswer->getValue(), trueDest, falseDest); + break; + } + case Predicates::EqualToQuestion: { + auto *equalToQuestion = cast(question); + builder.create( + loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), + trueDest, falseDest); + break; + } + case Predicates::ConstraintQuestion: { + auto *cstQuestion = cast(question); + SmallVector args; + for (Position *position : std::get<1>(cstQuestion->getValue())) + args.push_back(getValueAt(currentBlock, position)); + builder.create( + loc, std::get<0>(cstQuestion->getValue()), args, + std::get<2>(cstQuestion->getValue()).cast(), trueDest, + falseDest); + break; + } + default: + llvm_unreachable("Generating unknown Predicate operation"); + } +} + +template +static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, + ArrayRef> dests) { + std::vector values; + std::vector blocks; + values.reserve(dests.size()); + blocks.reserve(dests.size()); + for (auto &it : dests) { + blocks.push_back(it.second); + values.push_back(cast(it.first)->getValue()); + } + builder.create(val.getLoc(), val, values, defaultDest, blocks); +} + +void PatternLowering::generateSwitch( + Block *currentBlock, Qualifier *question, Value val, Block *defaultDest, + ArrayRef> dests) { + builder.setInsertionPointToEnd(currentBlock); + switch (question->getKind()) { + case Predicates::OperandCountQuestion: + return createSwitchOp(val, defaultDest, builder, dests); + case Predicates::ResultCountQuestion: + return createSwitchOp(val, defaultDest, builder, dests); + case Predicates::OperationNameQuestion: + return createSwitchOp(val, defaultDest, builder, + dests); + case Predicates::TypeQuestion: + return createSwitchOp( + val, defaultDest, builder, dests); + case Predicates::AttributeQuestion: + return createSwitchOp( + val, defaultDest,
+ builder, dests); + default: + llvm_unreachable("Generating unknown switch predicate."); + } +} + +SymbolRefAttr PatternLowering::generateRewriter( + pdl::PatternOp pattern, SmallVectorImpl &usedMatchValues) { + FuncOp rewriterFunc = + FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", + builder.getFunctionType(llvm::None, llvm::None)); + rewriterSymbolTable.insert(rewriterFunc); + + // Generate the rewriter function body. + builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); + + // Map an input operand of the pattern to a generated interpreter value. + DenseMap rewriteValues; + auto mapRewriteValue = [&](Value oldValue) { + Value &newValue = rewriteValues[oldValue]; + if (newValue) + return newValue; + + // Prefer materializing constants directly when possible. + Operation *oldOp = oldValue.getDefiningOp(); + if (pdl::AttributeOp attrOp = dyn_cast(oldOp)) { + if (Attribute value = attrOp.valueAttr()) { + return newValue = builder.create( + attrOp.getLoc(), value); + } + } else if (pdl::TypeOp typeOp = dyn_cast(oldOp)) { + if (TypeAttr type = typeOp.typeAttr()) { + return newValue = builder.create( + typeOp.getLoc(), type); + } + } + + // Otherwise, add this as an input to the rewriter. + Position *inputPos = valueToPosition.lookup(oldValue); + assert(inputPos && "expected value to be a pattern input"); + usedMatchValues.push_back(inputPos); + return newValue = rewriterFunc.front().addArgument(oldValue.getType()); + }; + + // If this is a custom rewriter, simply dispatch to the registered rewrite + // method. + pdl::RewriteOp rewriter = pattern.getRewriter(); + if (StringAttr rewriteName = rewriter.nameAttr()) { + Value root = mapRewriteValue(rewriter.root()); + SmallVector args = llvm::to_vector<4>( + llvm::map_range(rewriter.externalArgs(), mapRewriteValue)); + builder.create( + rewriter.getLoc(), rewriteName, root, args, + rewriter.externalConstParamsAttr()); + } else { + // Otherwise this is a dag rewriter defined using PDL operations.
+ for (Operation &rewriteOp : *rewriter.getBody()) { + llvm::TypeSwitch(&rewriteOp) + .Case([&](auto op) { + this->generateRewriter(op, rewriteValues, mapRewriteValue); + }); + } + } + + // Update the signature of the rewrite function. + rewriterFunc.setType(builder.getFunctionType( + llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), + /*results=*/llvm::None)); + + builder.create(rewriter.getLoc()); + return builder.getSymbolRefAttr( + pdl_interp::PDLInterpDialect::getRewriterModuleName(), + builder.getSymbolRefAttr(rewriterFunc)); +} + +void PatternLowering::generateRewriter( + pdl::AttributeOp attrOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + Value newAttr = builder.create( + attrOp.getLoc(), attrOp.valueAttr()); + rewriteValues.try_emplace(attrOp, newAttr); +} + +void PatternLowering::generateRewriter( + pdl::EraseOp eraseOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + builder.create(eraseOp.getLoc(), + mapRewriteValue(eraseOp.operation())); +} + +void PatternLowering::generateRewriter( + pdl::OperationOp operationOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + SmallVector operands; + for (Value operand : operationOp.operands()) + operands.push_back(mapRewriteValue(operand)); + + SmallVector attributes; + for (Value attr : operationOp.attributes()) + attributes.push_back(mapRewriteValue(attr)); + + SmallVector types; + generateOperationResultTypeRewriter(operationOp, types, rewriteValues, + mapRewriteValue); + + // Create the new operation. + Location loc = operationOp.getLoc(); + Value createdOp = builder.create( + loc, *operationOp.name(), types, operands, attributes, + operationOp.attributeNames()); + rewriteValues.try_emplace(operationOp.op(), createdOp); + + // Make all of the new operation results available.
+ OperandRange resultTypes = operationOp.types(); + for (auto it : llvm::enumerate(operationOp.results())) { + Value getResultVal = builder.create( + loc, builder.getType(), createdOp, it.index()); + rewriteValues.try_emplace(it.value(), getResultVal); + + // If any of the types have not been resolved, make those available as well. + Value &type = rewriteValues[resultTypes[it.index()]]; + if (!type) + type = builder.create(loc, getResultVal); + } +} + +void PatternLowering::generateRewriter( + pdl::CreateNativeOp createNativeOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + SmallVector arguments; + for (Value argument : createNativeOp.args()) + arguments.push_back(mapRewriteValue(argument)); + Value result = builder.create( + createNativeOp.getLoc(), createNativeOp.result().getType(), + createNativeOp.nameAttr(), arguments, createNativeOp.constParamsAttr()); + rewriteValues.try_emplace(createNativeOp, result); +} + +void PatternLowering::generateRewriter( + pdl::ReplaceOp replaceOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + // If the replacement was another operation, get its results. `pdl` allows + // for using an operation for simplicity, but the interpreter isn't as + // user facing. + ValueRange origOperands; + if (Value replOp = replaceOp.replOperation()) + origOperands = cast(replOp.getDefiningOp()).results(); + else + origOperands = replaceOp.replValues(); + + // If there are no replacement values, just create an erase instead.
+ if (origOperands.empty()) { + builder.create(replaceOp.getLoc(), + mapRewriteValue(replaceOp.operation())); + return; + } + + SmallVector replOperands; + for (Value operand : origOperands) + replOperands.push_back(mapRewriteValue(operand)); + builder.create( + replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); +} + +void PatternLowering::generateRewriter( + pdl::TypeOp typeOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + // If the type isn't constant, the users (e.g. OperationOp) will resolve this + // type. + if (TypeAttr typeAttr = typeOp.typeAttr()) { + Value newType = + builder.create(typeOp.getLoc(), typeAttr); + rewriteValues.try_emplace(typeOp, newType); + } +} + +void PatternLowering::generateOperationResultTypeRewriter( + pdl::OperationOp op, SmallVectorImpl &types, + DenseMap &rewriteValues, + function_ref mapRewriteValue) { + // Functor returning the op replaced via `use` to infer a type from, or null. + Block *rewriterBlock = op.getOperation()->getBlock(); + auto getReplacedOperationFrom = [&](OpOperand &use) -> Operation * { + // Check that the use corresponds to a ReplaceOp and that it is the + // replacement value, not the operation being replaced. + pdl::ReplaceOp replOpUser = dyn_cast(use.getOwner()); + if (!replOpUser || use.getOperandNumber() == 0) + return nullptr; + // Make sure the replaced operation was defined before this one. + Operation *replacedOp = replOpUser.operation().getDefiningOp(); + if (replacedOp->getBlock() != rewriterBlock || + replacedOp->isBeforeInBlock(op)) + return replacedOp; + return nullptr; + }; + + // If non-None/non-Null, this is an operation that is replaced by `op`. + // If Null, there is no full replacement operation for `op`. + // If None, a replacement operation hasn't been searched for.
+ Optional fullReplacedOperation; + bool hasTypeInference = op.hasTypeInference(); + auto resultTypeValues = op.types(); + types.reserve(resultTypeValues.size()); + for (auto it : llvm::enumerate(op.results())) { + Value result = it.value(), resultType = resultTypeValues[it.index()]; + + // Check for an already translated value. + if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { + types.push_back(existingRewriteValue); + continue; + } + + // Check for an input from the matcher. + if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { + types.push_back(mapRewriteValue(resultType)); + continue; + } + + // Check if the operation has type inference support. + if (hasTypeInference) { + types.push_back(builder.create(op.getLoc())); + continue; + } + + // Look for an operation that was replaced by `op`. The result type will be + // inferred from the result that was replaced. There is guaranteed to be a + // replacement for either the op, or this specific result. Note that this is + // guaranteed by the verifier of `pdl::OperationOp`. + Operation *replacedOp = nullptr; + if (!fullReplacedOperation.hasValue()) { + for (OpOperand &use : op.op().getUses()) + if ((replacedOp = getReplacedOperationFrom(use))) + break; + fullReplacedOperation = replacedOp; + } else { + replacedOp = fullReplacedOperation.getValue(); + } + // Infer from the result, as there was no fully replaced op. NOTE(review): below, the replaced op's type is still indexed by this op's result number; confirm that matches the replaced result in this case.
+ if (!replacedOp) { + for (OpOperand &use : result.getUses()) + if ((replacedOp = getReplacedOperationFrom(use))) + break; + assert(replacedOp && "expected replaced op to infer a result type from"); + } + + auto replOpOp = cast(replacedOp); + types.push_back(mapRewriteValue(replOpOp.types()[it.index()])); + } +} + +//===----------------------------------------------------------------------===// +// Conversion Pass +//===----------------------------------------------------------------------===// + +namespace { +struct PDLToPDLInterpPass + : public ConvertPDLToPDLInterpBase { + void runOnOperation() final; +}; +} // namespace + +/// Convert the given module containing PDL pattern operations into PDL +/// Interpreter operations. +void PDLToPDLInterpPass::runOnOperation() { + ModuleOp module = getOperation(); + + // Create the main matcher function. This function contains all of the match + // related functionality from patterns in the module. + OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); + FuncOp matcherFunc = builder.create( + module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), + builder.getFunctionType(builder.getType(), + /*results=*/llvm::None), + /*attrs=*/llvm::None); + + // Create a nested module to hold the functions invoked for rewriting the IR + // after a successful match. + ModuleOp rewriterModule = builder.create( + module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); + + // Generate the code for the patterns within the module. + PatternLowering generator(matcherFunc, rewriterModule); + generator.lower(module); + + // After generation, delete all of the pattern operations.
+ for (pdl::PatternOp pattern : + llvm::make_early_inc_range(module.getOps())) + pattern.erase(); +} + +std::unique_ptr> mlir::createPDLToPDLInterpPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -0,0 +1,530 @@ +//===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions for "predicates" used when converting PDL into +// a matcher tree. Predicates are composed of three different parts: +// +// * Positions +// - A position refers to a specific location on the input DAG, i.e. an +// existing MLIR entity being matched. These can be attributes, operands, +// operations, results, and types. Each position also defines a relation to +// its parent. For example, the operand `[0] -> 1` has a parent operation +// position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation +// position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge +// `[0] -> 1` (i.e. it is the defining op of operand 1). The only position +// without a parent is `[0]`, which refers to the root operation. +// * Questions +// - A question refers to a query on a specific positional value. For +// example, an operation name question checks the name of an operation +// position. +// * Answers +// - An answer is the expected result of a question. For example, when +// matching an operation with the name "foo.op", the question would be an +// operation name question, with an expected answer of "foo.op".
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ +#define MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace pdl_to_pdl_interp { +namespace Predicates { +/// An enumeration of the kinds of predicates. +enum Kind : unsigned { + // Positions, ordered by decreasing priority. + OperationPos, + OperandPos, + AttributePos, + ResultPos, + TypePos, + + // Questions, ordered by dependency and decreasing priority. + IsNotNullQuestion, + OperationNameQuestion, + TypeQuestion, + AttributeQuestion, + OperandCountQuestion, + ResultCountQuestion, + EqualToQuestion, + ConstraintQuestion, + + // Answers. + AttributeAnswer, + TrueAnswer, + OperationNameAnswer, + TypeAnswer, + UnsignedAnswer, +}; +} // end namespace Predicates + +/// Base class for all predicates, used to allow efficient pointer comparison. +template +class PredicateBase : public BaseT { +public: + using KeyTy = Key; + using Base = PredicateBase; + + template + explicit PredicateBase(KeyT &&key) + : BaseT(Kind), key(std::forward(key)) {} + + /// Get an instance of this predicate. + template + static ConcreteT *get(StorageUniquer &uniquer, Args &&... args) { + return uniquer.get(/*initFn=*/{}, std::forward(args)...); + } + + /// Construct an instance with the given storage allocator. + template + static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, + KeyT &&key) { + return new (alloc.allocate()) ConcreteT(std::forward(key)); + } + + /// Key comparison used for uniquing, and LLVM-style RTTI support. + bool operator==(const KeyTy &key) const { return this->key == key; } + static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } + + /// Return the key value of this predicate.
+ const KeyTy &getValue() const { return key; } + +protected: + KeyTy key; +}; + +/// Base storage for simple predicates that are uniqued only by their kind. +template +class PredicateBase : public BaseT { +public: + using Base = PredicateBase; + + explicit PredicateBase() : BaseT(Kind) {} + + static ConcreteT *get(StorageUniquer &uniquer) { + return uniquer.get(); + } + static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } +}; + +//===----------------------------------------------------------------------===// +// Positions +//===----------------------------------------------------------------------===// + +struct OperationPosition; + +/// A position describes a value on the input IR on which a predicate may be +/// applied, such as an operation or attribute. This enables re-use between +/// predicates, and assists generating bytecode and memory management. +/// +/// Operation positions form the base of other positions, which are formed +/// relative to a parent operation, e.g. OperandPosition<[0] -> 1>. Operations +/// are indexed by child index: [0, 1, 2] refers to the 3rd child of the 2nd +/// child of the root operation. +/// +/// Positions are linked to their parent position, which describes how to obtain +/// a positional value. As a concrete example, getting OperationPosition<[0, 1]> +/// would be `root->getOperand(1)->getDefiningOp()`, so its parent is +/// OperandPosition<[0] -> 1>, whose parent is OperationPosition<[0]>. +class Position : public StorageUniquer::BaseStorage { +public: + explicit Position(Predicates::Kind kind) : kind(kind) {} + virtual ~Position(); + + /// Returns the base node position. This is an array of indices. + virtual ArrayRef getIndex() const = 0; + + /// Returns the parent position. The root operation position has no parent. + Position *getParent() const { return parent; } + + /// Returns the kind of this position. + Predicates::Kind getKind() const { return kind; } + +protected: + /// Link to the parent position.
+ Position *parent = nullptr; + +private: + /// The kind of this position. + Predicates::Kind kind; +}; + +//===----------------------------------------------------------------------===// +// AttributePosition + +/// A position describing an attribute of an operation. +struct AttributePosition + : public PredicateBase, + Predicates::AttributePos> { + explicit AttributePosition(const KeyTy &key); + + /// Returns the index of this position, which is the index of its parent. + ArrayRef getIndex() const final { return parent->getIndex(); } + + /// Returns the attribute name of this position. + Identifier getName() const { return key.second; } +}; + +//===----------------------------------------------------------------------===// +// OperandPosition + +/// A position describing an operand of an operation. +struct OperandPosition + : public PredicateBase, + Predicates::OperandPos> { + explicit OperandPosition(const KeyTy &key); + + /// Returns the index of this position, which is the index of its parent. + ArrayRef getIndex() const final { return parent->getIndex(); } + + /// Returns the operand number of this position. + unsigned getOperandNumber() const { return key.second; } +}; + +//===----------------------------------------------------------------------===// +// OperationPosition + +/// An operation position describes an operation node in the IR. Other position +/// kinds are formed with respect to an operation position. +struct OperationPosition + : public PredicateBase, + Predicates::OperationPos> { + using Base::Base; + + /// Gets the root position, which is always [0]. + static OperationPosition *getRoot(StorageUniquer &uniquer) { + return get(uniquer, ArrayRef(0)); + } + /// Gets a node position for the given index. + static OperationPosition *get(StorageUniquer &uniquer, + ArrayRef index); + + /// Constructs an instance with the given storage allocator.
+ static OperationPosition *construct(StorageUniquer::StorageAllocator &alloc, + ArrayRef key) { + return Base::construct(alloc, alloc.copyInto(key)); + } + + /// Returns the index of this position. + ArrayRef getIndex() const final { return key; } + + /// Returns true if this operation position corresponds to the root, i.e. [0]. + bool isRoot() const { return key.size() == 1 && key[0] == 0; } +}; + +//===----------------------------------------------------------------------===// +// ResultPosition + +/// A position describing a result of an operation. +struct ResultPosition + : public PredicateBase, + Predicates::ResultPos> { + explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } + + /// Returns the index of this position, which is the index of its parent. + ArrayRef getIndex() const final { return key.first->getIndex(); } + + /// Returns the result number of this position. + unsigned getResultNumber() const { return key.second; } +}; + +//===----------------------------------------------------------------------===// +// TypePosition + +/// A position describing the result type of an entity, i.e. an Attribute, +/// Operand, or Result. +struct TypePosition : public PredicateBase { + explicit TypePosition(const KeyTy &key) : Base(key) { + assert((isa(key) || isa(key) || + isa(key)) && + "expected parent to be an attribute, operand, or result"); + parent = key; + } + + /// Returns the index of this position. + ArrayRef getIndex() const final { return key->getIndex(); } +}; + +//===----------------------------------------------------------------------===// +// Qualifiers +//===----------------------------------------------------------------------===// + +/// An ordinal predicate consists of a "Question" and a set of acceptable +/// "Answers" (later converted to ordinal values). A predicate will query some +/// property of a positional value and decide what to do based on the result. +/// +/// This makes top-level predicate representations ordinal (SwitchOp).
Later, +/// predicates that end up with only one acceptable answer (including all +/// boolean kinds) will be converted to boolean predicates (PredicateOp) in the +/// matcher. +/// +/// For simplicity, both are represented as "qualifiers", with a base kind and +/// perhaps additional properties. For example, all OperationName predicates ask +/// the same question, but GenericConstraint predicates may ask different ones. +class Qualifier : public StorageUniquer::BaseStorage { +public: + explicit Qualifier(Predicates::Kind kind) : kind(kind) {} + + /// Returns the kind of this qualifier. + Predicates::Kind getKind() const { return kind; } + +private: + /// The kind of this position. + Predicates::Kind kind; +}; + +//===----------------------------------------------------------------------===// +// Answers + +/// An Answer representing an `Attribute` value. +struct AttributeAnswer + : public PredicateBase { + using Base::Base; +}; + +/// An Answer representing an `OperationName` value. +struct OperationNameAnswer + : public PredicateBase { + using Base::Base; +}; + +/// An Answer representing a boolean `true` value. +struct TrueAnswer + : PredicateBase { + using Base::Base; +}; + +/// An Answer representing a `Type` value. +struct TypeAnswer : public PredicateBase { + using Base::Base; +}; + +/// An Answer representing an unsigned value. +struct UnsignedAnswer + : public PredicateBase { + using Base::Base; +}; + +//===----------------------------------------------------------------------===// +// Questions + +/// Compare an `Attribute` to a constant value. +struct AttributeQuestion + : public PredicateBase {}; + +/// Apply a parameterized constraint to multiple position values. +struct ConstraintQuestion + : public PredicateBase< + ConstraintQuestion, Qualifier, + std::tuple, Attribute>, + Predicates::ConstraintQuestion> { + using Base::Base; + + /// Construct an instance with the given storage allocator. 
+ static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, + KeyTy key) { + return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), + alloc.copyInto(std::get<1>(key)), + std::get<2>(key)}); + } +}; + +/// Compare the equality of two values. +struct EqualToQuestion + : public PredicateBase { + using Base::Base; +}; + +/// Compare a positional value with null, i.e. check if it exists. +struct IsNotNullQuestion + : public PredicateBase {}; + +/// Compare the number of operands of an operation with a known value. +struct OperandCountQuestion + : public PredicateBase {}; + +/// Compare the name of an operation with a known value. +struct OperationNameQuestion + : public PredicateBase {}; + +/// Compare the number of results of an operation with a known value. +struct ResultCountQuestion + : public PredicateBase {}; + +/// Compare the type of an attribute or value with a known type. +struct TypeQuestion : public PredicateBase {}; + +//===----------------------------------------------------------------------===// +// PredicateUniquer +//===----------------------------------------------------------------------===// + +/// This class provides a storage uniquer that is used to allocate predicate +/// instances. +class PredicateUniquer : public StorageUniquer { +public: + PredicateUniquer() { + // Register the types of Positions with the uniquer. + registerParametricStorageType(); + registerParametricStorageType(); + registerParametricStorageType(); + registerParametricStorageType(); + registerParametricStorageType(); + + // Register the types of Questions with the uniquer. + registerParametricStorageType(); + registerParametricStorageType(); + registerParametricStorageType(); + registerParametricStorageType(); + registerSingletonStorageType(); + + // Register the types of Answers with the uniquer. 
+ registerParametricStorageType(); + registerParametricStorageType(); + registerSingletonStorageType(); + registerSingletonStorageType(); + registerSingletonStorageType(); + registerSingletonStorageType(); + registerSingletonStorageType(); + registerSingletonStorageType(); + } +}; + +//===----------------------------------------------------------------------===// +// PredicateBuilder +//===----------------------------------------------------------------------===// + +/// This class provides utilties for constructing predicates. +class PredicateBuilder { +public: + PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx) + : uniquer(uniquer), ctx(ctx) {} + + //===--------------------------------------------------------------------===// + // Positions + //===--------------------------------------------------------------------===// + + /// Returns the root operation position. + Position *getRoot() { return OperationPosition::getRoot(uniquer); } + + /// Returns the parent position defining the value held by the given operand. + Position *getParent(OperandPosition *p) { + std::vector index = p->getIndex(); + index.push_back(p->getOperandNumber()); + return OperationPosition::get(uniquer, index); + } + + /// Returns an attribute position for an attribute of the given operation. + Position *getAttribute(OperationPosition *p, StringRef name) { + return AttributePosition::get(uniquer, p, Identifier::get(name, ctx)); + } + + /// Returns an operand position for an operand of the given operation. + Position *getOperand(OperationPosition *p, unsigned operand) { + return OperandPosition::get(uniquer, p, operand); + } + + /// Returns a result position for a result of the given operation. + Position *getResult(OperationPosition *p, unsigned result) { + return ResultPosition::get(uniquer, p, result); + } + + /// Returns a type position for the given entity. 
+ Position *getType(Position *p) { return TypePosition::get(uniquer, p); } + + //===--------------------------------------------------------------------===// + // Qualifiers + //===--------------------------------------------------------------------===// + + /// An ordinal predicate consists of a "Question" and a set of acceptable + /// "Answers" (later converted to ordinal values). A predicate will query some + /// property of a positional value and decide what to do based on the result. + using Predicate = std::pair; + + /// Create a predicate comparing an attribute to a known value. + Predicate getAttributeConstraint(Attribute attr) { + return {AttributeQuestion::get(uniquer), + AttributeAnswer::get(uniquer, attr)}; + } + + /// Create a predicate comparing two values. + Predicate getEqualTo(Position *pos) { + return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; + } + + /// Create a predicate that applies a generic constraint. + Predicate getConstraint(StringRef name, ArrayRef pos, + Attribute params) { + return { + ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, params)), + TrueAnswer::get(uniquer)}; + } + + /// Create a predicate comparing a value with null. + Predicate getIsNotNull() { + return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)}; + } + + /// Create a predicate comparing the number of operands of an operation to a + /// known value. + Predicate getOperandCount(unsigned count) { + return {OperandCountQuestion::get(uniquer), + UnsignedAnswer::get(uniquer, count)}; + } + + /// Create a predicate comparing the name of an operation to a known value. + Predicate getOperationName(StringRef name) { + return {OperationNameQuestion::get(uniquer), + OperationNameAnswer::get(uniquer, OperationName(name, ctx))}; + } + + /// Create a predicate comparing the number of results of an operation to a + /// known value. 
+ Predicate getResultCount(unsigned count) { + return {ResultCountQuestion::get(uniquer), + UnsignedAnswer::get(uniquer, count)}; + } + + /// Create a predicate comparing the type of an attribute or value to a known + /// type. + Predicate getTypeConstraint(Type type) { + return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)}; + } + +private: + /// The uniquer used when allocating predicate nodes. + PredicateUniquer &uniquer; + + /// The current MLIR context. + MLIRContext *ctx; +}; + +} // end namespace pdl_to_pdl_interp +} // end namespace mlir + +#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp @@ -0,0 +1,49 @@ +//===- Predicate.cpp - Pattern predicates -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Predicate.h"
+
+using namespace mlir;
+using namespace mlir::pdl_to_pdl_interp;
+
+//===----------------------------------------------------------------------===//
+// Positions
+//===----------------------------------------------------------------------===//
+
+Position::~Position() = default;
+
+//===----------------------------------------------------------------------===//
+// AttributePosition
+
+AttributePosition::AttributePosition(const KeyTy &key) : Base(key) {
+  parent = key.first;
+}
+
+//===----------------------------------------------------------------------===//
+// OperandPosition
+
+OperandPosition::OperandPosition(const KeyTy &key) : Base(key) {
+  parent = key.first;
+}
+
+//===----------------------------------------------------------------------===//
+// OperationPosition
+
+OperationPosition *OperationPosition::get(StorageUniquer &uniquer,
+                                          ArrayRef index) {
+  assert(!index.empty() && "expected at least one index");
+
+  // Unless this is the root, parent it to the operand whose value it defines.
+  Position *parent = nullptr;
+  if (index.size() > 1) {
+    auto *node = OperationPosition::get(uniquer, index.drop_back());
+    parent = OperandPosition::get(uniquer, std::make_pair(node, index.back()));
+  }
+  return uniquer.get(
+      [parent](OperationPosition *node) { node->parent = parent; }, index);
+}
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
@@ -0,0 +1,200 @@
+//===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for nodes of a tree structure for representing
+// the general control flow within a pattern match.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
+#define MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
+
+#include "Predicate.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "llvm/ADT/MapVector.h"
+
+namespace mlir {
+namespace pdl_to_pdl_interp {
+
+class MatcherNode;
+
+/// A PositionalPredicate is a predicate that is associated with a specific
+/// positional value.
+struct PositionalPredicate {
+  PositionalPredicate(Position *pos,
+                      const PredicateBuilder::Predicate &predicate)
+      : position(pos), question(predicate.first), answer(predicate.second) {}
+
+  /// The position the predicate is applied to.
+  Position *position;
+
+  /// The question that the predicate applies.
+  Qualifier *question;
+
+  /// The expected answer of the predicate.
+  Qualifier *answer;
+};
+
+//===----------------------------------------------------------------------===//
+// MatcherNode
+//===----------------------------------------------------------------------===//
+
+/// This class represents the base of a predicate matcher node.
+class MatcherNode {
+public:
+  virtual ~MatcherNode() = default;
+
+  /// Given a module containing PDL pattern operations, generate a matcher tree
+  /// using the patterns within the given module and return the root matcher
+  /// node. `valueToPosition` is a map that is populated with the original
+  /// pdl values and their corresponding positions in the matcher tree.
+  static std::unique_ptr
+  generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
+                      DenseMap &valueToPosition);
+
+  /// Returns the position on which the question predicate should be checked.
+  Position *getPosition() const { return position; }
+
+  /// Returns the predicate checked on this node.
+  Qualifier *getQuestion() const { return question; }
+
+  /// Returns the node that should be visited if this, or a subsequent node
+  /// fails.
+  std::unique_ptr &getFailureNode() { return failureNode; }
+
+  /// Sets the node that should be visited if this, or a subsequent node fails.
+  void setFailureNode(std::unique_ptr node) {
+    failureNode = std::move(node);
+  }
+
+  /// Returns the unique type ID of this matcher instance. This should not be
+  /// used directly, and is provided to support type casting.
+  TypeID getMatcherTypeID() const { return matcherTypeID; }
+
+protected:
+  MatcherNode(TypeID matcherTypeID, Position *position = nullptr,
+              Qualifier *question = nullptr,
+              std::unique_ptr failureNode = nullptr);
+
+private:
+  /// The position on which the predicate should be checked.
+  Position *position;
+
+  /// The predicate that is checked on the given position.
+  Qualifier *question;
+
+  /// The node to visit if this node fails.
+  std::unique_ptr failureNode;
+
+  /// An owning store for the failure node if it is owned by this node.
+  std::unique_ptr failureNodeStorage;
+
+  /// A unique identifier for the derived matcher node, used for type casting.
+  TypeID matcherTypeID;
+};
+
+//===----------------------------------------------------------------------===//
+// BoolNode
+
+/// A BoolNode denotes a question with a boolean-like result. These nodes branch
+/// to a single node on a successful result, otherwise defaulting to the failure
+/// node.
+struct BoolNode : public MatcherNode {
+  BoolNode(Position *position, Qualifier *question, Qualifier *answer,
+           std::unique_ptr successNode,
+           std::unique_ptr failureNode = nullptr);
+
+  /// Returns if the given matcher node is an instance of this class, used to
+  /// support type casting.
+  static bool classof(const MatcherNode *node) {
+    return node->getMatcherTypeID() == TypeID::get();
+  }
+
+  /// Returns the expected answer of this boolean node.
+  Qualifier *getAnswer() const { return answer; }
+
+  /// Returns the node that should be visited on success.
+  std::unique_ptr &getSuccessNode() { return successNode; }
+
+private:
+  /// The expected answer of this boolean node.
+  Qualifier *answer;
+
+  /// The next node if this node succeeds. Otherwise, go to the failure node.
+  std::unique_ptr successNode;
+};
+
+//===----------------------------------------------------------------------===//
+// ExitNode
+
+/// An ExitNode is a special sentinel node that denotes the end of the matcher.
+struct ExitNode : public MatcherNode {
+  ExitNode() : MatcherNode(TypeID::get()) {}
+
+  /// Returns if the given matcher node is an instance of this class, used to
+  /// support type casting.
+  static bool classof(const MatcherNode *node) {
+    return node->getMatcherTypeID() == TypeID::get();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// SuccessNode
+
+/// A SuccessNode denotes that a given high level pattern has successfully been
+/// matched. This does not terminate the matcher, as there may be multiple
+/// successful matches.
+struct SuccessNode : public MatcherNode {
+  explicit SuccessNode(pdl::PatternOp pattern,
+                       std::unique_ptr failureNode);
+
+  /// Returns if the given matcher node is an instance of this class, used to
+  /// support type casting.
+  static bool classof(const MatcherNode *node) {
+    return node->getMatcherTypeID() == TypeID::get();
+  }
+
+  /// Return the high level pattern operation that is matched with this node.
+  pdl::PatternOp getPattern() const { return pattern; }
+
+private:
+  /// The high level pattern operation that was successfully matched with this
+  /// node.
+  pdl::PatternOp pattern;
+};
+
+//===----------------------------------------------------------------------===//
+// SwitchNode
+
+/// A SwitchNode denotes a question with multiple potential results. These nodes
+/// branch to a specific node based on the result of the question.
+struct SwitchNode : public MatcherNode {
+  SwitchNode(Position *position, Qualifier *question);
+
+  /// Returns if the given matcher node is an instance of this class, used to
+  /// support type casting.
+  static bool classof(const MatcherNode *node) {
+    return node->getMatcherTypeID() == TypeID::get();
+  }
+
+  /// Returns the children of this switch node. The children are contained
+  /// within a mapping between the various case answers to destination matcher
+  /// nodes.
+  using ChildMapT = llvm::MapVector>;
+  ChildMapT &getChildren() { return children; }
+
+private:
+  /// Switch predicate "answers" select the child. Answers that are not found
+  /// default to the failure node.
+  ChildMapT children;
+};
+
+} // end namespace pdl_to_pdl_interp
+} // end namespace mlir
+
+#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -0,0 +1,462 @@
+//===- PredicateTree.cpp - Predicate tree merging ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PredicateTree.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::pdl_to_pdl_interp;
+
+//===----------------------------------------------------------------------===//
+// Predicate List Building
+//===----------------------------------------------------------------------===//
+
+/// Compares the depths of two positions.
+static bool comparePosDepth(Position *lhs, Position *rhs) {
+  return lhs->getIndex().size() < rhs->getIndex().size();
+}
+
+/// Collect the tree predicates anchored at the given value.
+static void getTreePredicates(std::vector &predList,
+                              Value val, PredicateBuilder &builder,
+                              DenseMap &inputs,
+                              Position *pos) {
+  // Make sure this input value is accessible to the rewrite.
+  auto it = inputs.try_emplace(val, pos);
+
+  // If this is an input value that has been visited in the tree, add a
+  // constraint to ensure that both instances refer to the same value.
+  if (!it.second &&
+      isa(val.getDefiningOp())) {
+    auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth);
+    predList.emplace_back(minMaxPositions.second,
+                          builder.getEqualTo(minMaxPositions.first));
+    return;
+  }
+
+  // Check for a per-position predicate to apply.
+  switch (pos->getKind()) {
+  case Predicates::AttributePos: {
+    assert(val.getType().isa() &&
+           "expected attribute type");
+    pdl::AttributeOp attr = cast(val.getDefiningOp());
+    predList.emplace_back(pos, builder.getIsNotNull());
+
+    // If the attribute has a type, add a type constraint.
+    if (Value type = attr.type()) {
+      getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
+
+      // Check for a constant value of the attribute.
+    } else if (Optional value = attr.value()) {
+      predList.emplace_back(pos, builder.getAttributeConstraint(*value));
+    }
+    break;
+  }
+  case Predicates::OperandPos: {
+    assert(val.getType().isa() && "expected value type");
+
+    // Prevent traversal into a null value.
+    predList.emplace_back(pos, builder.getIsNotNull());
+
+    // If this is a typed input, add a type constraint.
+    if (auto in = val.getDefiningOp()) {
+      if (Value type = in.type()) {
+        getTreePredicates(predList, type, builder, inputs,
+                          builder.getType(pos));
+      }
+
+      // Otherwise, recurse into the parent node.
+    } else if (auto parentOp = val.getDefiningOp()) {
+      getTreePredicates(predList, parentOp.op(), builder, inputs,
+                        builder.getParent(cast(pos)));
+    }
+    break;
+  }
+  case Predicates::OperationPos: {
+    assert(val.getType().isa() && "expected operation");
+    pdl::OperationOp op = cast(val.getDefiningOp());
+    OperationPosition *opPos = cast(pos);
+
+    // Ensure getDefiningOp returns a non-null operation.
+    if (!opPos->isRoot())
+      predList.emplace_back(pos, builder.getIsNotNull());
+
+    // Check that this is the correct operation (applies to every position).
+    if (Optional opName = op.name())
+      predList.emplace_back(pos, builder.getOperationName(*opName));
+
+    // Check that the operation has the proper number of operands and results.
+    OperandRange operands = op.operands();
+    ResultRange results = op.results();
+    predList.emplace_back(pos, builder.getOperandCount(operands.size()));
+    predList.emplace_back(pos, builder.getResultCount(results.size()));
+
+    // Recurse into any attributes, operands, or results.
+    for (auto attrIt : llvm::zip(op.attributeNames(), op.attributes())) {
+      getTreePredicates(
+          predList, std::get<1>(attrIt), builder, inputs,
+          builder.getAttribute(opPos,
+                               std::get<0>(attrIt).cast().getValue()));
+    }
+    for (auto operandIt : llvm::enumerate(operands))
+      getTreePredicates(predList, operandIt.value(), builder, inputs,
+                        builder.getOperand(opPos, operandIt.index()));
+
+    // Recurse into each of the results of the operation.
+    for (auto resultIt : llvm::enumerate(results)) {
+      getTreePredicates(predList, resultIt.value(), builder, inputs,
+                        builder.getResult(opPos, resultIt.index()));
+    }
+    break;
+  }
+  case Predicates::ResultPos: {
+    assert(val.getType().isa() && "expected value type");
+    pdl::OperationOp parentOp = cast(val.getDefiningOp());
+
+    // Prevent traversing a null value.
+    predList.emplace_back(pos, builder.getIsNotNull());
+
+    // Traverse the type constraint.
+    unsigned resultNo = cast(pos)->getResultNumber();
+    getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs,
+                      builder.getType(pos));
+    break;
+  }
+  case Predicates::TypePos: {
+    assert(val.getType().isa() && "expected type");
+    pdl::TypeOp typeOp = cast(val.getDefiningOp());
+
+    // Check for a constraint on a constant type.
+    if (Optional type = typeOp.type())
+      predList.emplace_back(pos, builder.getTypeConstraint(*type));
+    break;
+  }
+  default:
+    llvm_unreachable("unknown position kind");
+  }
+}
+
+/// Collect all of the predicates related to constraints within the given
+/// pattern operation.
+static void collectConstraintPredicates(
+    pdl::PatternOp pattern, std::vector &predList,
+    PredicateBuilder &builder, DenseMap &inputs) {
+  for (auto op : pattern.body().getOps()) {
+    OperandRange arguments = op.args();
+    ArrayAttr parameters = op.constParamsAttr();
+
+    std::vector allPositions;
+    allPositions.reserve(arguments.size());
+    for (Value arg : arguments)
+      allPositions.push_back(inputs.lookup(arg));
+
+    // Push the constraint to the furthest position.
+    Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
+                                      comparePosDepth);
+    PredicateBuilder::Predicate pred =
+        builder.getConstraint(op.name(), allPositions, parameters);
+    predList.emplace_back(pos, pred);
+  }
+}
+
+/// Given a pattern operation, build the set of matcher predicates necessary to
+/// match this pattern.
+static void buildPredicateList(pdl::PatternOp pattern,
+                               PredicateBuilder &builder,
+                               std::vector &predList,
+                               DenseMap &valueToPosition) {
+  getTreePredicates(predList, pattern.getRewriter().root(), builder,
+                    valueToPosition, builder.getRoot());
+  collectConstraintPredicates(pattern, predList, builder, valueToPosition);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Predicate Tree Merging
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// This class represents a specific predicate applied to a position, and
+/// provides hashing and ordering operators. This class allows for computing a
+/// frequency sum and ordering predicates based on a cost model.
+struct OrderedPredicate {
+  OrderedPredicate(const std::pair &ip)
+      : position(ip.first), question(ip.second) {}
+  OrderedPredicate(const PositionalPredicate &ip)
+      : position(ip.position), question(ip.question) {}
+
+  /// The position this predicate is applied to.
+  Position *position;
+
+  /// The question that is applied by this predicate onto the position.
+  Qualifier *question;
+
+  /// The first and second order benefit sums.
+  /// The primary sum is the number of occurrences of this predicate among all
+  /// of the patterns.
+  unsigned primary = 0;
+  /// The secondary sum is a squared summation of the primary sum of all of the
+  /// predicates within each pattern that contains this predicate. This allows
+  /// for favoring predicates that are more commonly shared within a pattern, as
+  /// opposed to those shared across patterns.
+  unsigned secondary = 0;
+
+  /// A map between a pattern operation and the answer to the predicate question
+  /// within that pattern.
+  DenseMap patternToAnswer;
+
+  /// Returns true if this predicate is ordered before `other`, based on the
+  /// cost model.
+  bool operator<(const OrderedPredicate &other) const {
+    // Sort by:
+    //  * higher primary and secondary order sums
+    //  * lower depth
+    //  * lower position dependency
+    //  * lower predicate dependency.
+    auto *otherPos = other.position;
+    return std::make_tuple(primary, secondary, otherPos->getIndex().size(),
+                           otherPos->getKind(), other.question->getKind()) >
+           std::make_tuple(other.primary, other.secondary,
+                           position->getIndex().size(), position->getKind(),
+                           question->getKind());
+  }
+};
+
+/// A DenseMapInfo for OrderedPredicate based solely on the position and
+/// question.
+struct OrderedPredicateDenseInfo {
+  using Base = DenseMapInfo>;
+
+  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
+  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
+  static bool isEqual(const OrderedPredicate &lhs,
+                      const OrderedPredicate &rhs) {
+    return lhs.position == rhs.position && lhs.question == rhs.question;
+  }
+  static unsigned getHashValue(const OrderedPredicate &p) {
+    return llvm::hash_combine(p.position, p.question);
+  }
+};
+
+/// This class wraps a set of ordered predicates that are used within a specific
+/// pattern operation.
+struct OrderedPredicateList {
+  OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {}
+
+  pdl::PatternOp pattern;
+  DenseSet predicates;
+};
+} // end anonymous namespace
+
+/// Returns true if the given matcher refers to the same predicate as the given
+/// ordered predicate. This means that the position and questions of the two
+/// match.
+static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
+  return node->getPosition() == predicate->position &&
+         node->getQuestion() == predicate->question;
+}
+
+/// Get or insert a child matcher for the given parent switch node, given a
+/// predicate and parent pattern.
+static std::unique_ptr &getOrCreateChild(SwitchNode *node,
+                                         OrderedPredicate *predicate,
+                                         pdl::PatternOp pattern) {
+  assert(isSamePredicate(node, predicate) &&
+         "expected matcher to equal the given predicate");
+
+  auto it = predicate->patternToAnswer.find(pattern);
+  assert(it != predicate->patternToAnswer.end() &&
+         "expected pattern to exist in predicate");
+  return node->getChildren().insert({it->second, nullptr}).first->second;
+}
+
+/// Build the matcher CFG by "pushing" patterns through by sorted predicate
+/// order. A pattern will traverse as far as possible using common predicates
+/// and then either diverge from the CFG or reach the end of a branch and start
+/// creating new nodes.
+static void propagatePattern(std::unique_ptr &node,
+                             OrderedPredicateList &list,
+                             std::vector::iterator current,
+                             std::vector::iterator end) {
+  if (current == end) {
+    // We've hit the end of a pattern, so create a successful result node.
+    node = std::make_unique(list.pattern, std::move(node));
+
+    // If the pattern doesn't contain this predicate, ignore it.
+  } else if (list.predicates.find(*current) == list.predicates.end()) {
+    propagatePattern(node, list, std::next(current), end);
+
+    // If the current matcher node is invalid, create a new one for this
+    // position and continue propagation.
+  } else if (!node) {
+    // Create a new node at this position and continue.
+    node = std::make_unique((*current)->position,
+                                        (*current)->question);
+    propagatePattern(
+        getOrCreateChild(cast(&*node), *current, list.pattern),
+        list, std::next(current), end);
+
+    // If the matcher has already been created, and it is for this predicate we
+    // continue propagation to the child.
+  } else if (isSamePredicate(node.get(), *current)) {
+    propagatePattern(
+        getOrCreateChild(cast(&*node), *current, list.pattern),
+        list, std::next(current), end);
+
+    // If the matcher doesn't match the current predicate, insert a branch as
+    // the common set of matchers has diverged.
+  } else {
+    propagatePattern(node->getFailureNode(), list, current, end);
+  }
+}
+
+/// Fold any switch nodes nested under `node` to boolean nodes when possible.
+/// `node` is updated in-place if it is a switch.
+static void foldSwitchToBool(std::unique_ptr &node) {
+  if (!node)
+    return;
+
+  if (SwitchNode *switchNode = dyn_cast(&*node)) {
+    SwitchNode::ChildMapT &children = switchNode->getChildren();
+    for (auto &it : children)
+      foldSwitchToBool(it.second);
+
+    // If the node only contains one child, collapse it into a boolean predicate
+    // node.
+    if (children.size() == 1) {
+      auto childIt = children.begin();
+      node = std::make_unique(
+          node->getPosition(), node->getQuestion(), childIt->first,
+          std::move(childIt->second), std::move(node->getFailureNode()));
+    }
+  } else if (BoolNode *boolNode = dyn_cast(&*node)) {
+    foldSwitchToBool(boolNode->getSuccessNode());
+  }
+
+  foldSwitchToBool(node->getFailureNode());
+}
+
+/// Insert an exit node at the end of the failure path of the `root`.
+static void insertExitNode(std::unique_ptr *root) {
+  while (*root)
+    root = &(*root)->getFailureNode();
+  *root = std::make_unique();
+}
+
+/// Given a module containing PDL pattern operations, generate a matcher tree
+/// using the patterns within the given module and return the root matcher node.
+std::unique_ptr<MatcherNode>
+MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
+                                 DenseMap<Value, Position *> &valueToPosition) {
+  // Collect the set of predicates contained within the pattern operations of
+  // the module.
+  SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16>
+      patternsAndPredicates;
+  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
+    std::vector<PositionalPredicate> predicateList;
+    buildPredicateList(pattern, builder, predicateList, valueToPosition);
+    patternsAndPredicates.emplace_back(pattern, std::move(predicateList));
+  }
+
+  // Associate a pattern result with each unique predicate.
+  DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
+  for (auto &patternAndPredList : patternsAndPredicates) {
+    for (auto &predicate : patternAndPredList.second) {
+      auto it = uniqued.insert(predicate);
+      it.first->patternToAnswer.try_emplace(patternAndPredList.first,
+                                            predicate.answer);
+    }
+  }
+
+  // Associate each pattern to a set of its ordered predicates for later lookup.
+  // At the same time, record every unique predicate in the order it is first
+  // referenced. Iterating `uniqued` directly would visit predicates in the
+  // order of their pointer-based hashes, which depends on allocation addresses
+  // and would make the tie-breaking of the sort below nondeterministic.
+  // `uniqued` is not modified past this point, so pointers into it are stable.
+  std::vector<OrderedPredicateList> lists;
+  lists.reserve(patternsAndPredicates.size());
+  std::vector<OrderedPredicate *> ordered;
+  ordered.reserve(uniqued.size());
+  for (auto &patternAndPredList : patternsAndPredicates) {
+    OrderedPredicateList list(patternAndPredList.first);
+    for (auto &predicate : patternAndPredList.second) {
+      OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
+      list.predicates.insert(orderedPredicate);
+
+      // A zero primary sum means this is the first reference to the predicate.
+      if (orderedPredicate->primary == 0)
+        ordered.push_back(orderedPredicate);
+
+      // Increment the primary sum for each reference to a particular predicate.
+      ++orderedPredicate->primary;
+    }
+    lists.push_back(std::move(list));
+  }
+  assert(ordered.size() == uniqued.size() &&
+         "expected every unique predicate to be referenced by a pattern");
+
+  // For a particular pattern, get the total primary sum and add it to the
+  // secondary sum of each predicate. Square the primary sums to emphasize
+  // shared predicates within rather than across patterns.
+  for (auto &list : lists) {
+    unsigned total = 0;
+    for (auto *predicate : list.predicates)
+      total += predicate->primary * predicate->primary;
+    for (auto *predicate : list.predicates)
+      predicate->secondary += total;
+  }
+
+  // Sort the set of predicates now that the cost primary and secondary sums
+  // have been computed. The stable sort keeps predicates that compare
+  // equivalent under the cost model in first-reference order, so the
+  // generated matcher does not depend on hash iteration order.
+  std::stable_sort(
+      ordered.begin(), ordered.end(),
+      [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
+
+  // Build the matchers for each of the pattern predicate lists.
+  std::unique_ptr<MatcherNode> root;
+  for (OrderedPredicateList &list : lists)
+    propagatePattern(root, list, ordered.begin(), ordered.end());
+
+  // Collapse the graph and insert the exit node.
+  foldSwitchToBool(root);
+  insertExitNode(&root);
+  return root;
+}
+
+//===----------------------------------------------------------------------===//
+// MatcherNode
+//===----------------------------------------------------------------------===//
+
+MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
+                         std::unique_ptr<MatcherNode> failureNode)
+    : position(p), question(q), failureNode(std::move(failureNode)),
+      matcherTypeID(matcherTypeID) {}
+
+//===----------------------------------------------------------------------===//
+// BoolNode
+//===----------------------------------------------------------------------===//
+
+BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
+                   std::unique_ptr<MatcherNode> successNode,
+                   std::unique_ptr<MatcherNode> failureNode)
+    : MatcherNode(TypeID::get<BoolNode>(), position, question,
+                  std::move(failureNode)),
+      answer(answer), successNode(std::move(successNode)) {}
+
+//===----------------------------------------------------------------------===//
+// SuccessNode
+//===----------------------------------------------------------------------===//
+
+SuccessNode::SuccessNode(pdl::PatternOp pattern,
+                         std::unique_ptr<MatcherNode> failureNode)
+    : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
+                  /*question=*/nullptr, std::move(failureNode)),
+      pattern(pattern) {}
+
+//===----------------------------------------------------------------------===//
+// SwitchNode
+//===----------------------------------------------------------------------===//
+
+SwitchNode::SwitchNode(Position *position, Qualifier *question)
+    : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -33,6 +33,10 @@
 class NVVMDialect;
 } // end namespace NVVM
 
+namespace pdl_interp {
+class PDLInterpDialect;
+} // end namespace pdl_interp
+
 namespace ROCDL {
 class ROCDLDialect;
 } // end namespace ROCDL
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -0,0 +1,145 @@
+// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
+
+// CHECK-LABEL: module @empty_module
+module @empty_module {
+// CHECK: func @matcher(%{{.*}}: !pdl.operation)
+// CHECK-NEXT: pdl_interp.finalize
+}
+
+// -----
+
+// CHECK-LABEL: module @simple
+module @simple {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: pdl_interp.check_operation_name of %[[ROOT]] is "foo.op" -> ^bb2, ^bb1
+  // CHECK: ^bb1:
+  // CHECK: pdl_interp.finalize
+  // CHECK: ^bb2:
+  // CHECK: pdl_interp.check_operand_count of %[[ROOT]] is 0 -> ^bb3, ^bb1
+  // CHECK: ^bb3:
+  // CHECK: pdl_interp.check_result_count of %[[ROOT]] is 0 -> ^bb4, ^bb1
+  // CHECK: ^bb4:
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter
+  // CHECK-SAME: benefit(1), loc([%[[ROOT]]]), root("foo.op") -> ^bb1
+
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[REWRITE_ROOT:.*]]: !pdl.operation)
+  // CHECK: pdl_interp.apply_rewrite "rewriter" on %[[REWRITE_ROOT]]
+  // CHECK: pdl_interp.finalize
+  pdl.pattern : benefit(1) {
+    %root = pdl.operation "foo.op"()
+    pdl.rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @attributes
+module @attributes {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // Check the value of "attr".
+  // CHECK-DAG: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of %[[ROOT]]
+  // CHECK-DAG: pdl_interp.is_not_null %[[ATTR]] : !pdl.attribute
+  // CHECK-DAG: pdl_interp.check_attribute %[[ATTR]] is 10 : i64
+
+  // Check the type of "attr1".
+  // CHECK-DAG: %[[ATTR1:.*]] = pdl_interp.get_attribute "attr1" of %[[ROOT]]
+  // CHECK-DAG: pdl_interp.is_not_null %[[ATTR1]] : !pdl.attribute
+  // CHECK-DAG: %[[ATTR1_TYPE:.*]] = pdl_interp.get_attribute_type of %[[ATTR1]]
+  // CHECK-DAG: pdl_interp.check_type %[[ATTR1_TYPE]] is i64
+  pdl.pattern : benefit(1) {
+    %type = pdl.type : i64
+    %attr = pdl.attribute 10 : i64
+    %attr1 = pdl.attribute : %type
+    %root = pdl.operation {"attr" = %attr, "attr1" = %attr1}
+    pdl.rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @constraints
+module @constraints {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
+  // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
+  // CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]] : !pdl.value, !pdl.value)
+
+  pdl.pattern : benefit(1) {
+    %input0 = pdl.input
+    %input1 = pdl.input
+
+    pdl.apply_constraint "multi_constraint"[true](%input0, %input1 : !pdl.value, !pdl.value)
+
+    %root = pdl.operation(%input0, %input1)
+    pdl.rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @inputs
+module @inputs {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 2
+
+  // Get the input and check the type.
+  // CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
+  // CHECK-DAG: pdl_interp.is_not_null %[[INPUT]] : !pdl.value
+  // CHECK-DAG: %[[INPUT_TYPE:.*]] = pdl_interp.get_value_type of %[[INPUT]]
+  // CHECK-DAG: pdl_interp.check_type %[[INPUT_TYPE]] is i64
+
+  // Get the second operand and check that it is equal to the first.
+  // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
+  // CHECK-DAG: pdl_interp.are_equal %[[INPUT]], %[[INPUT1]] : !pdl.value
+  pdl.pattern : benefit(1) {
+    %type = pdl.type : i64
+    %input = pdl.input : %type
+    %root = pdl.operation(%input, %input)
+    pdl.rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @results
+module @results {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: pdl_interp.check_result_count of %[[ROOT]] is 2
+
+  // Get the first result and check the type.
+  // CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
+  // CHECK-DAG: pdl_interp.is_not_null %[[RESULT]] : !pdl.value
+  // CHECK-DAG: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
+  // CHECK-DAG: pdl_interp.check_type %[[RESULT_TYPE]] is i32
+
+  // Get the second result; its type is unconstrained, so it is not checked.
+  // CHECK-DAG: %[[RESULT1:.*]] = pdl_interp.get_result 1 of %[[ROOT]]
+  // CHECK-NOT: pdl_interp.get_value_type of %[[RESULT1]]
+  pdl.pattern : benefit(1) {
+    %type1 = pdl.type : i32
+    %type2 = pdl.type
+    %root, %results:2 = pdl.operation -> %type1, %type2
+    pdl.rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @switch_result_types
+module @switch_result_types {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
+  // CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
+  // CHECK: pdl_interp.switch_type %[[RESULT_TYPE]] to [i32, i64]
+  pdl.pattern : benefit(1) {
+    %type = pdl.type : i32
+    %root, %result = pdl.operation -> %type
+    pdl.rewrite %root with "rewriter"
+  }
+  pdl.pattern : benefit(1) {
+    %type = pdl.type : i64
+    %root, %result = pdl.operation -> %type
+    pdl.rewrite %root with "rewriter"
+  }
+}
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
@@ -0,0 +1,180 @@
+// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
+
+// CHECK-LABEL: module @external
+module @external {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value)
+  // CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[INPUT]] : !pdl.value) on %[[ROOT]]
+  pdl.pattern : benefit(1) {
+    %input = pdl.input
+    %root = pdl.operation "foo.op"(%input)
+    pdl.rewrite %root with "rewriter"[true](%input : !pdl.value)
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @erase
+module @erase {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: pdl_interp.erase %[[ROOT]]
+  // CHECK: pdl_interp.finalize
+  pdl.pattern : benefit(1) {
+    %root = pdl.operation "foo.op"
+    pdl.rewrite %root {
+      pdl.erase %root
+    }
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @operation_attributes
+module @operation_attributes {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[ATTR:.*]]: !pdl.attribute, %[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[ATTR1:.*]] = pdl_interp.create_attribute true
+  // CHECK: pdl_interp.create_operation "foo.op"() {"attr" = %[[ATTR]], "attr1" = %[[ATTR1]]}
+  pdl.pattern : benefit(1) {
+    %attr = pdl.attribute
+    %root = pdl.operation "foo.op" {"attr" = %attr}
+    pdl.rewrite %root {
+      %attr1 = pdl.attribute true
+      %newOp = pdl.operation "foo.op" {"attr" = %attr, "attr1" = %attr1}
+      pdl.erase %root
+    }
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @operation_operands
+module @operation_operands {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]])
+  // CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
+  // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
+  pdl.pattern : benefit(1) {
+    %operand = pdl.input
+    %root = pdl.operation "foo.op"(%operand)
+    pdl.rewrite %root {
+      %type = pdl.type : i32
+      %newOp, %result = pdl.operation "foo.op"(%operand) -> %type
+      %newOp1 = pdl.operation "foo.op2"(%result)
+      pdl.erase %root
+    }
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @operation_result_types
+module @operation_result_types {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type, %[[TYPE1:.*]]: !pdl.type
+  // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]], %[[TYPE1]]
+  pdl.pattern : benefit(1) {
+    %rootType = pdl.type
+    %rootType1 = pdl.type
+    %root, %results:2 = pdl.operation "foo.op" -> %rootType, %rootType1
+    pdl.rewrite %root {
+      %newType1 = pdl.type
+      %newOp, %newResults:2 = pdl.operation "foo.op" -> %rootType, %newType1
+      pdl.replace %root with %newOp
+    }
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @operation_result_types_infer_from_value_replacement
+module @operation_result_types_infer_from_value_replacement {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type
+  // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
+  pdl.pattern : benefit(1) {
+    %rootType = pdl.type
+    %root, %result = pdl.operation "foo.op" -> %rootType
+    pdl.rewrite %root {
+      %newType = pdl.type
+      %newOp, %newResult = pdl.operation "foo.op" -> %newType
+      pdl.replace %root with (%newResult)
+    }
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @replace_with_op
+module @replace_with_op {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation
+  // CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
+  // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
+  pdl.pattern : benefit(1) {
+    %type = pdl.type : i32
+    %root, %result = pdl.operation "foo.op" -> %type
+    pdl.rewrite %root {
+      %newOp, %newResult = pdl.operation "foo.op" -> %type
+      pdl.replace %root with %newOp
+    }
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @replace_with_values
+module @replace_with_values {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation
+  // CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
+  // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
+  pdl.pattern : benefit(1) {
+    %type = pdl.type : i32
+    %root, %result = pdl.operation "foo.op" -> %type
+    pdl.rewrite %root {
+      %newOp, %newResult = pdl.operation "foo.op" -> %type
+      pdl.replace %root with (%newResult)
+    }
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @replace_with_no_results
+module @replace_with_no_results {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: pdl_interp.create_operation "foo.op"
+  // CHECK: pdl_interp.erase %[[ROOT]]
+  pdl.pattern : benefit(1) {
+    %root = pdl.operation "foo.op"
+    pdl.rewrite %root {
+      %newOp = pdl.operation "foo.op"
+      pdl.replace %root with %newOp
+    }
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @create_native
+module @create_native {
+  // CHECK: module @rewriters
+  // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[TYPE:.*]] = pdl_interp.create_native "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
+  // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
+  pdl.pattern : benefit(1) {
+    %type = pdl.type
+    %root, %result = pdl.operation "foo.op" -> %type
+    pdl.rewrite %root {
+      %newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type
+      %newOp, %newResult = pdl.operation "foo.op" -> %newType
+      pdl.replace %root with %newOp
+    }
+  }
+}