diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td @@ -48,7 +48,7 @@ %resultType = pdl.type %inputOperand = pdl.operand - %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType + %root = pdl.operation "foo.op"(%inputOperand) -> %resultType pdl.rewrite %root { pdl.replace %root with (%inputOperand) } diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -177,7 +177,7 @@ let description = [{ `pdl.operand` operations capture external operand edges into an operation node that originate from operations or block arguments not otherwise - specified within the pattern (e.g. via `pdl.operation`). These operations + specified within the pattern (e.g. via `pdl.result`). These operations define individual operands of a given operation. A `pdl.operand` may partially constrain an operand by specifying an expected value type (via a `pdl.type` operation). @@ -223,8 +223,8 @@ `pdl.operation`s are composed of a name, and a set of attribute, operand, and result type values, that map to what those that would be on a constructed instance of that operation. The results of a `pdl.operation` are - a handle to the operation itself, and a handle to each of the operation - result values. + a handle to the operation itself. Handles to the results of the operation + can be extracted via `pdl.result`. When used within a matching context, the name of the operation may be omitted. @@ -241,7 +241,7 @@ ```mlir // Define an instance of a `foo.op` operation. - %op, %results:4 = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type + %op = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type ``` }]; @@ -250,8 +250,13 @@ Variadic:$attributes, StrArrayAttr:$attributeNames, Variadic:$types); - let results = (outs PDL_Operation:$op, - Variadic:$results); + let results = (outs PDL_Operation:$op); + let assemblyFormat = [{ + ($name^)? (`(` $operands^ `)`)? + custom($attributes, $attributeNames) + (`->` $types^)? attr-dict + }]; + let builders = [ OpBuilder<(ins CArg<"Optional", "llvm::None">:$name, CArg<"ValueRange", "llvm::None">:$operandValues, @@ -259,10 +264,9 @@ CArg<"ValueRange", "llvm::None">:$attrValues, CArg<"ValueRange", "llvm::None">:$resultTypes), [{ auto nameAttr = name ? StringAttr() : $_builder.getStringAttr(*name); - build($_builder, $_state, $_builder.getType(), {}, nameAttr, + build($_builder, $_state, $_builder.getType(), nameAttr, operandValues, attrValues, $_builder.getStrArrayAttr(attrNames), resultTypes); - $_state.types.append(resultTypes.size(), $_builder.getType()); }]>, ]; let extraClassDeclaration = [{ @@ -293,7 +297,7 @@ pdl.pattern : benefit(1) { %resultType = pdl.type %inputOperand = pdl.operand - %root, %results = pdl.operation "foo.op"(%inputOperand) -> (%resultType) + %root = pdl.operation "foo.op"(%inputOperand) -> (%resultType) pdl.rewrite %root { pdl.replace %root with (%inputOperand) } @@ -368,6 +372,39 @@ }]; } +//===----------------------------------------------------------------------===// +// pdl::ResultOp +//===----------------------------------------------------------------------===// + +def PDL_ResultOp : PDL_Op<"result"> { + let summary = "Extract a result from an operation"; + let description = [{ + `pdl.result` operations extract result edges from an operation node within + a pattern or rewrite region. The provided index is zero-based, and + represents the concrete result to extract, i.e. this is not the result index + as defined by the ODS definition of the operation. + + Example: + + ```mlir + // Extract a result: + %operation = pdl.operation ... + %result = pdl.result 1 of %operation + + // Imagine the following IR being matched: + %result_0, %result_1 = foo.op ... + + // If the example pattern snippet above were matching against `foo.op` in + // the IR snippted, `%result` would correspond to `%result_1`. + ``` + }]; + + let arguments = (ins PDL_Operation:$parent, I32Attr:$index); + let results = (outs PDL_Value:$val); + let assemblyFormat = "$index `of` $parent attr-dict"; + let verifier = ?; +} + //===----------------------------------------------------------------------===// // pdl::RewriteOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -85,6 +85,9 @@ void generateRewriter(pdl::ReplaceOp replaceOp, DenseMap &rewriteValues, function_ref mapRewriteValue); + void generateRewriter(pdl::ResultOp resultOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); void generateRewriter(pdl::TypeOp typeOp, DenseMap &rewriteValues, function_ref mapRewriteValue); @@ -457,9 +460,10 @@ for (Operation &rewriteOp : *rewriter.getBody()) { llvm::TypeSwitch(&rewriteOp) .Case([&](auto op) { - this->generateRewriter(op, rewriteValues, mapRewriteValue); - }); + pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::TypeOp>( + [&](auto op) { + this->generateRewriter(op, rewriteValues, mapRewriteValue); + }); } } @@ -511,17 +515,15 @@ operationOp.attributeNames()); rewriteValues[operationOp.op()] = createdOp; - // Make all of the new operation results available. - OperandRange resultTypes = operationOp.types(); - for (auto it : llvm::enumerate(operationOp.results())) { + // Generate accesses for any results that have their types constrained. + for (auto it : llvm::enumerate(operationOp.types())) { + Value &type = rewriteValues[it.value()]; + if (type) + continue; + Value getResultVal = builder.create( loc, builder.getType(), createdOp, it.index()); - rewriteValues[it.value()] = getResultVal; - - // If any of the types have not been resolved, make those available as well. - Value &type = rewriteValues[resultTypes[it.index()]]; - if (!type) - type = builder.create(loc, getResultVal); + type = builder.create(loc, getResultVal); } } @@ -540,29 +542,41 @@ void PatternLowering::generateRewriter( pdl::ReplaceOp replaceOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { + SmallVector replOperands; + // If the replacement was another operation, get its results. `pdl` allows // for using an operation for simplicitly, but the interpreter isn't as // user facing. - ValueRange origOperands; - if (Value replOp = replaceOp.replOperation()) - origOperands = cast(replOp.getDefiningOp()).results(); - else - origOperands = replaceOp.replValues(); + if (Value replOp = replaceOp.replOperation()) { + pdl::OperationOp op = cast(replOp.getDefiningOp()); + for (unsigned i = 0, e = op.types().size(); i < e; ++i) + replOperands.push_back(builder.create( + replOp.getLoc(), builder.getType(), + mapRewriteValue(replOp), i)); + } else { + for (Value operand : replaceOp.replValues()) + replOperands.push_back(mapRewriteValue(operand)); + } // If there are no replacement values, just create an erase instead. - if (origOperands.empty()) { + if (replOperands.empty()) { builder.create(replaceOp.getLoc(), mapRewriteValue(replaceOp.operation())); return; } - SmallVector replOperands; - for (Value operand : origOperands) - replOperands.push_back(mapRewriteValue(operand)); builder.create( replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); } +void PatternLowering::generateRewriter( + pdl::ResultOp resultOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + rewriteValues[resultOp] = builder.create( + resultOp.getLoc(), builder.getType(), + mapRewriteValue(resultOp.parent()), resultOp.index()); +} + void PatternLowering::generateRewriter( pdl::TypeOp typeOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { @@ -602,8 +616,8 @@ bool hasTypeInference = op.hasTypeInference(); auto resultTypeValues = op.types(); types.reserve(resultTypeValues.size()); - for (auto it : llvm::enumerate(op.results())) { - Value result = it.value(), resultType = resultTypeValues[it.index()]; + for (auto it : llvm::enumerate(resultTypeValues)) { + Value resultType = it.value(); // Check for an already translated value. if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { @@ -633,16 +647,11 @@ if ((replacedOp = getReplacedOperationFrom(use))) break; fullReplacedOperation = replacedOp; + assert(fullReplacedOperation && + "expected replaced op to infer a result type from"); } else { replacedOp = fullReplacedOperation.getValue(); } - // Infer from the result, as there was no fully replaced op. - if (!replacedOp) { - for (OpOperand &use : result.getUses()) - if ((replacedOp = getReplacedOperationFrom(use))) - break; - assert(replacedOp && "expected replaced op to infer a result type from"); - } auto replOpOp = cast(replacedOp); types.push_back(mapRewriteValue(replOpOp.types()[it.index()])); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -433,7 +433,7 @@ Position *getRoot() { return OperationPosition::getRoot(uniquer); } /// Returns the parent position defining the value held by the given operand. - Position *getParent(OperandPosition *p) { + OperationPosition *getParent(OperandPosition *p) { std::vector index = p->getIndex(); index.push_back(p->getOperandNumber()); return OperationPosition::get(uniquer, index); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::pdl_to_pdl_interp; @@ -20,151 +21,181 @@ // Predicate List Building //===----------------------------------------------------------------------===// +static void getTreePredicates(std::vector &predList, + Value val, PredicateBuilder &builder, + DenseMap &inputs, + Position *pos); + /// Compares the depths of two positions. static bool comparePosDepth(Position *lhs, Position *rhs) { return lhs->getIndex().size() < rhs->getIndex().size(); } -/// Collect the tree predicates anchored at the given value. static void getTreePredicates(std::vector &predList, Value val, PredicateBuilder &builder, DenseMap &inputs, - Position *pos) { - // Make sure this input value is accessible to the rewrite. - auto it = inputs.try_emplace(val, pos); + AttributePosition *pos) { + assert(val.getType().isa() && "expected attribute type"); + pdl::AttributeOp attr = cast(val.getDefiningOp()); + predList.emplace_back(pos, builder.getIsNotNull()); + + // If the attribute has a type or value, add a constraint. + if (Value type = attr.type()) + getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); + else if (Attribute value = attr.valueAttr()) + predList.emplace_back(pos, builder.getAttributeConstraint(value)); +} - // If this is an input value that has been visited in the tree, add a - // constraint to ensure that both instances refer to the same value. - if (!it.second && - isa(val.getDefiningOp())) { - auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth); - predList.emplace_back(minMaxPositions.second, - builder.getEqualTo(minMaxPositions.first)); - return; - } +static void getTreePredicates(std::vector &predList, + Value val, PredicateBuilder &builder, + DenseMap &inputs, + OperandPosition *pos) { + assert(val.getType().isa() && "expected value type"); - // Check for a per-position predicate to apply. - switch (pos->getKind()) { - case Predicates::AttributePos: { - assert(val.getType().isa() && - "expected attribute type"); - pdl::AttributeOp attr = cast(val.getDefiningOp()); - predList.emplace_back(pos, builder.getIsNotNull()); + // Prevent traversal into a null value. + predList.emplace_back(pos, builder.getIsNotNull()); - // If the attribute has a type, add a type constraint. - if (Value type = attr.type()) { + // If this is a typed operand, add a type constraint. + if (auto in = val.getDefiningOp()) { + if (Value type = in.type()) getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); - // Check for a constant value of the attribute. - } else if (Optional value = attr.value()) { - predList.emplace_back(pos, builder.getAttributeConstraint(*value)); - } - break; + // Otherwise, recurse into a result node. + } else if (auto resultOp = val.getDefiningOp()) { + OperationPosition *parentPos = builder.getParent(pos); + Position *resultPos = builder.getResult(parentPos, resultOp.index()); + predList.emplace_back(parentPos, builder.getIsNotNull()); + predList.emplace_back(resultPos, builder.getEqualTo(pos)); + getTreePredicates(predList, resultOp.parent(), builder, inputs, parentPos); } - case Predicates::OperandPos: { - assert(val.getType().isa() && "expected value type"); +} + +static void getTreePredicates(std::vector &predList, + Value val, PredicateBuilder &builder, + DenseMap &inputs, + OperationPosition *pos) { + assert(val.getType().isa() && "expected operation"); + pdl::OperationOp op = cast(val.getDefiningOp()); + OperationPosition *opPos = cast(pos); - // Prevent traversal into a null value. + // Ensure getDefiningOp returns a non-null operation. + if (!opPos->isRoot()) predList.emplace_back(pos, builder.getIsNotNull()); - // If this is a typed operand, add a type constraint. - if (auto in = val.getDefiningOp()) { - if (Value type = in.type()) { - getTreePredicates(predList, type, builder, inputs, - builder.getType(pos)); - } - - // Otherwise, recurse into the parent node. - } else if (auto parentOp = val.getDefiningOp()) { - getTreePredicates(predList, parentOp.op(), builder, inputs, - builder.getParent(cast(pos))); - } - break; + // Check that this is the correct root operation. + if (Optional opName = op.name()) + predList.emplace_back(pos, builder.getOperationName(*opName)); + + // Check that the operation has the proper number of operands and results. + OperandRange operands = op.operands(); + OperandRange types = op.types(); + predList.emplace_back(pos, builder.getOperandCount(operands.size())); + predList.emplace_back(pos, builder.getResultCount(types.size())); + + // Recurse into any attributes, operands, or results. + for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { + getTreePredicates( + predList, std::get<1>(it), builder, inputs, + builder.getAttribute(opPos, + std::get<0>(it).cast().getValue())); } - case Predicates::OperationPos: { - assert(val.getType().isa() && "expected operation"); - pdl::OperationOp op = cast(val.getDefiningOp()); - OperationPosition *opPos = cast(pos); - - // Ensure getDefiningOp returns a non-null operation. - if (!opPos->isRoot()) - predList.emplace_back(pos, builder.getIsNotNull()); - - // Check that this is the correct root operation. - if (Optional opName = op.name()) - predList.emplace_back(pos, builder.getOperationName(*opName)); - - // Check that the operation has the proper number of operands and results. - OperandRange operands = op.operands(); - ResultRange results = op.results(); - predList.emplace_back(pos, builder.getOperandCount(operands.size())); - predList.emplace_back(pos, builder.getResultCount(results.size())); - - // Recurse into any attributes, operands, or results. - for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { - getTreePredicates( - predList, std::get<1>(it), builder, inputs, - builder.getAttribute(opPos, - std::get<0>(it).cast().getValue())); - } - for (auto operandIt : llvm::enumerate(operands)) - getTreePredicates(predList, operandIt.value(), builder, inputs, - builder.getOperand(opPos, operandIt.index())); - - // Only recurse into results that are not referenced in the source tree. - for (auto resultIt : llvm::enumerate(results)) { - getTreePredicates(predList, resultIt.value(), builder, inputs, - builder.getResult(opPos, resultIt.index())); - } - break; + for (auto operandIt : llvm::enumerate(operands)) { + getTreePredicates(predList, operandIt.value(), builder, inputs, + builder.getOperand(opPos, operandIt.index())); + } + for (auto &resultIt : llvm::enumerate(types)) { + auto *resultPos = builder.getResult(pos, resultIt.index()); + predList.emplace_back(resultPos, builder.getIsNotNull()); + getTreePredicates(predList, resultIt.value(), builder, inputs, + builder.getType(resultPos)); } - case Predicates::ResultPos: { - assert(val.getType().isa() && "expected value type"); - pdl::OperationOp parentOp = cast(val.getDefiningOp()); +} - // Prevent traversing a null value. - predList.emplace_back(pos, builder.getIsNotNull()); +static void getTreePredicates(std::vector &predList, + Value val, PredicateBuilder &builder, + DenseMap &inputs, + TypePosition *pos) { + assert(val.getType().isa() && "expected value type"); + pdl::TypeOp typeOp = cast(val.getDefiningOp()); - // Traverse the type constraint. - unsigned resultNo = cast(pos)->getResultNumber(); - getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs, - builder.getType(pos)); - break; - } - case Predicates::TypePos: { - assert(val.getType().isa() && "expected value type"); - pdl::TypeOp typeOp = cast(val.getDefiningOp()); - - // Check for a constraint on a constant type. - if (Optional type = typeOp.type()) - predList.emplace_back(pos, builder.getTypeConstraint(*type)); - break; - } - default: - llvm_unreachable("unknown position kind"); + // Check for a constraint on a constant type. + if (Optional type = typeOp.type()) + predList.emplace_back(pos, builder.getTypeConstraint(*type)); +} + +/// Collect the tree predicates anchored at the given value. +static void getTreePredicates(std::vector &predList, + Value val, PredicateBuilder &builder, + DenseMap &inputs, + Position *pos) { + // Make sure this input value is accessible to the rewrite. + auto it = inputs.try_emplace(val, pos); + if (!it.second) { + // If this is an input value that has been visited in the tree, add a + // constraint to ensure that both instances refer to the same value. + if (isa( + val.getDefiningOp())) { + auto minMaxPositions = + std::minmax(pos, it.first->second, comparePosDepth); + predList.emplace_back(minMaxPositions.second, + builder.getEqualTo(minMaxPositions.first)); + } + return; } + + TypeSwitch(pos) + .Case([&](auto *derivedPos) { + getTreePredicates(predList, val, builder, inputs, derivedPos); + }) + .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); } /// Collect all of the predicates related to constraints within the given /// pattern operation. -static void collectConstraintPredicates( - pdl::PatternOp pattern, std::vector &predList, - PredicateBuilder &builder, DenseMap &inputs) { - for (auto op : pattern.body().getOps()) { - OperandRange arguments = op.args(); - ArrayAttr parameters = op.constParamsAttr(); - - std::vector allPositions; - allPositions.reserve(arguments.size()); - for (Value arg : arguments) - allPositions.push_back(inputs.lookup(arg)); - - // Push the constraint to the furthest position. - Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), - comparePosDepth); - PredicateBuilder::Predicate pred = - builder.getConstraint(op.name(), std::move(allPositions), parameters); - predList.emplace_back(pos, pred); +static void getConstraintPredicates(pdl::ApplyConstraintOp op, + std::vector &predList, + PredicateBuilder &builder, + DenseMap &inputs) { + OperandRange arguments = op.args(); + ArrayAttr parameters = op.constParamsAttr(); + + std::vector allPositions; + allPositions.reserve(arguments.size()); + for (Value arg : arguments) + allPositions.push_back(inputs.lookup(arg)); + + // Push the constraint to the furthest position. + Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), + comparePosDepth); + PredicateBuilder::Predicate pred = + builder.getConstraint(op.name(), std::move(allPositions), parameters); + predList.emplace_back(pos, pred); +} + +static void getResultPredicates(pdl::ResultOp op, + std::vector &predList, + PredicateBuilder &builder, + DenseMap &inputs) { + Position *&resultPos = inputs[op]; + if (resultPos) + return; + auto *parentPos = cast(inputs.lookup(op.parent())); + resultPos = builder.getResult(parentPos, op.index()); + predList.emplace_back(resultPos, builder.getIsNotNull()); +} + +/// Collect all of the predicates that cannot be determined via walking the +/// tree. +static void getNonTreePredicates(pdl::PatternOp pattern, + std::vector &predList, + PredicateBuilder &builder, + DenseMap &inputs) { + for (Operation &op : pattern.body().getOps()) { + if (auto constraintOp = dyn_cast(&op)) + getConstraintPredicates(constraintOp, predList, builder, inputs); + else if (auto resultOp = dyn_cast(&op)) + getResultPredicates(resultOp, predList, builder, inputs); } } @@ -176,7 +207,7 @@ DenseMap &valueToPosition) { getTreePredicates(predList, pattern.getRewriter().root(), builder, valueToPosition, builder.getRoot()); - collectConstraintPredicates(pattern, predList, builder, valueToPosition); + getNonTreePredicates(pattern, predList, builder, valueToPosition); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -31,21 +31,36 @@ >(); } +//===----------------------------------------------------------------------===// +// PDL Operations +//===----------------------------------------------------------------------===// + /// Returns true if the given operation is used by a "binding" pdl operation /// within the main matcher body of a `pdl.pattern`. +static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) { + for (Operation *user : op->getUsers()) { + if (user->getBlock() != matcherBlock) + continue; + if (isa(user)) + return true; + // A result by itself is not binding, it must also be bound. + if (isa(user) && hasBindingUseInMatcher(user, matcherBlock)) + return true; + } + return false; +} + +/// Returns success if the given operation is used by a "binding" pdl operation +/// within the main matcher body of a `pdl.pattern`. On failure, emits an error +/// with the given context message. static LogicalResult verifyHasBindingUseInMatcher(Operation *op, StringRef bindableContextStr = "`pdl.operation`") { // If the pattern is not a pattern, there is nothing to do. if (!isa(op->getParentOp())) return success(); - Block *matcherBlock = op->getBlock(); - for (Operation *user : op->getUsers()) { - if (user->getBlock() != matcherBlock) - continue; - if (isa(user)) - return success(); - } + if (hasBindingUseInMatcher(op, op->getBlock())) + return success(); return op->emitOpError() << "expected a bindable (i.e. " << bindableContextStr << ") user when defined in the matcher body of a `pdl.pattern`"; @@ -89,37 +104,12 @@ // pdl::OperationOp //===----------------------------------------------------------------------===// -static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) { +static ParseResult parseOperationOpAttributes( + OpAsmParser &p, SmallVectorImpl &attrOperands, + ArrayAttr &attrNamesAttr) { Builder &builder = p.getBuilder(); - - // Parse the optional operation name. - bool startsWithOperands = succeeded(p.parseOptionalLParen()); - bool startsWithAttributes = - !startsWithOperands && succeeded(p.parseOptionalLBrace()); - bool startsWithOpName = false; - if (!startsWithAttributes && !startsWithOperands) { - StringAttr opName; - OptionalParseResult opNameResult = - p.parseOptionalAttribute(opName, "name", state.attributes); - startsWithOpName = opNameResult.hasValue(); - if (startsWithOpName && failed(*opNameResult)) - return failure(); - } - - // Parse the operands. - SmallVector operands; - if (startsWithOperands || - (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) { - if (p.parseOperandList(operands) || p.parseRParen() || - p.resolveOperands(operands, builder.getType(), - state.operands)) - return failure(); - } - - // Parse the attributes. SmallVector attrNames; - if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) { - SmallVector attrOps; + if (succeeded(p.parseOptionalLBrace())) { do { StringAttr nameAttr; OpAsmParser::OperandType operand; @@ -127,68 +117,29 @@ p.parseOperand(operand)) return failure(); attrNames.push_back(nameAttr); - attrOps.push_back(operand); + attrOperands.push_back(operand); } while (succeeded(p.parseOptionalComma())); - - if (p.parseRBrace() || - p.resolveOperands(attrOps, builder.getType(), - state.operands)) - return failure(); - } - state.addAttribute("attributeNames", builder.getArrayAttr(attrNames)); - state.addTypes(builder.getType()); - - // Parse the result types. - SmallVector opResultTypes; - if (succeeded(p.parseOptionalArrow())) { - if (p.parseOperandList(opResultTypes) || - p.resolveOperands(opResultTypes, builder.getType(), - state.operands)) + if (p.parseRBrace()) return failure(); - state.types.append(opResultTypes.size(), builder.getType()); } - - if (p.parseOptionalAttrDict(state.attributes)) - return failure(); - - int32_t operandSegmentSizes[] = {static_cast(operands.size()), - static_cast(attrNames.size()), - static_cast(opResultTypes.size())}; - state.addAttribute("operand_segment_sizes", - builder.getI32VectorAttr(operandSegmentSizes)); + attrNamesAttr = builder.getArrayAttr(attrNames); return success(); } -static void print(OpAsmPrinter &p, OperationOp op) { - p << "pdl.operation "; - if (Optional name = op.name()) - p << '"' << *name << '"'; - - auto operandValues = op.operands(); - if (!operandValues.empty()) - p << '(' << operandValues << ')'; - - // Emit the optional attributes. - ArrayAttr attrNames = op.attributeNames(); - if (!attrNames.empty()) { - Operation::operand_range attrArgs = op.attributes(); - p << " {"; - interleaveComma(llvm::seq(0, attrNames.size()), p, - [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); - p << '}'; - } - - // Print the result type constraints of the operation. - if (!op.results().empty()) - p << " -> " << op.types(); - p.printOptionalAttrDict(op->getAttrs(), - {"attributeNames", "name", "operand_segment_sizes"}); +static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, + OperandRange attrArgs, + ArrayAttr attrNames) { + if (attrNames.empty()) + return; + p << " {"; + interleaveComma(llvm::seq(0, attrNames.size()), p, + [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); + p << '}'; } /// Verifies that the result types of this operation, defined within a /// `pdl.rewrite`, can be inferred. static LogicalResult verifyResultTypesAreInferrable(OperationOp op, - ResultRange opResults, OperandRange resultTypes) { // Functor that returns if the given use can be used to infer a type. Block *rewriterBlock = op->getBlock(); @@ -210,8 +161,8 @@ return success(); // Otherwise, make sure each of the types can be inferred. - for (int i : llvm::seq(0, opResults.size())) { - Operation *resultTypeOp = resultTypes[i].getDefiningOp(); + for (auto it : llvm::enumerate(resultTypes)) { + Operation *resultTypeOp = it.value().getDefiningOp(); assert(resultTypeOp && "expected valid result type operation"); // If the op was defined by a `create_native`, it is guaranteed to be @@ -232,14 +183,11 @@ if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp)) continue; - // Otherwise, check to see if any uses of the result can infer the type. - if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse)) - continue; return op .emitOpError("must have inferable or constrained result types when " "nested within `pdl.rewrite`") .attachNote() - .append("result type #", i, " was not constrained"); + .append("result type #", it.index(), " was not constrained"); } return success(); } @@ -259,19 +207,10 @@ << " values"; } - OperandRange resultTypes = op.types(); - auto opResults = op.results(); - if (resultTypes.size() != opResults.size()) { - return op.emitOpError() << "expected the same number of result values and " - "result type constraints, got " - << opResults.size() << " results and " - << resultTypes.size() << " constraints"; - } - // If the operation is within a rewrite body and doesn't have type inference, // ensure that the result types can be resolved. if (isWithinRewrite && !op.hasTypeInference()) { - if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes))) + if (failed(verifyResultTypesAreInferrable(op, op.types()))) return failure(); } @@ -344,37 +283,9 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(ReplaceOp op) { - auto sourceOp = cast(op.operation().getDefiningOp()); - auto sourceOpResults = sourceOp.results(); - auto replValues = op.replValues(); - - if (Value replOpVal = op.replOperation()) { - auto replOp = cast(replOpVal.getDefiningOp()); - auto replOpResults = replOp.results(); - if (sourceOpResults.size() != replOpResults.size()) { - return op.emitOpError() - << "expected source operation to have the same number of results " - "as the replacement operation, replacement operation provided " - << replOpResults.size() << " but expected " - << sourceOpResults.size(); - } - - if (!replValues.empty()) { - return op.emitOpError() << "expected no replacement values to be provided" - " when the replacement operation is present"; - } - - return success(); - } - - if (sourceOpResults.size() != replValues.size()) { - return op.emitOpError() - << "expected source operation to have the same number of results " - "as the provided replacement values, found " - << replValues.size() << " replacement values but expected " - << sourceOpResults.size(); - } - + if (op.replOperation() && !op.replValues().empty()) + return op.emitOpError() << "expected no replacement values to be provided" + " when the replacement operation is present"; return success(); } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -63,15 +63,16 @@ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) // CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]] - // CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]] : !pdl.value, !pdl.value) + // CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]] + // CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]], %[[RESULT]] pdl.pattern : benefit(1) { %input0 = pdl.operand %input1 = pdl.operand - - pdl.apply_constraint "multi_constraint"[true](%input0, %input1 : !pdl.value, !pdl.value) - %root = pdl.operation(%input0, %input1) + %result0 = pdl.result 0 of %root + + pdl.apply_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value) pdl.rewrite %root with "rewriter" } } @@ -107,19 +108,52 @@ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) // CHECK: pdl_interp.check_result_count of %[[ROOT]] is 2 - // Get the input and check the type. + // Get the result and check the type. // CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]] // CHECK-DAG: pdl_interp.is_not_null %[[RESULT]] : !pdl.value // CHECK-DAG: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]] // CHECK-DAG: pdl_interp.check_type %[[RESULT_TYPE]] is i32 - // Get the second operand and check that it is equal to the first. - // CHECK-DAG: %[[RESULT1:.*]] = pdl_interp.get_result 1 of %[[ROOT]] - // CHECK-NOT: pdl_interp.get_value_type of %[[RESULT1]] + // The second result doesn't have any constraints, so we don't generate an + // access for it. + // CHECK-NOT: pdl_interp.get_result 1 of %[[ROOT]] pdl.pattern : benefit(1) { %type1 = pdl.type : i32 %type2 = pdl.type - %root, %results:2 = pdl.operation -> %type1, %type2 + %root = pdl.operation -> %type1, %type2 + pdl.rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @results_as_operands +module @results_as_operands { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + + // Get the first result and check it matches the first operand. + // CHECK-DAG: %[[OPERAND_0:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] + // CHECK-DAG: %[[DEF_OP_0:.*]] = pdl_interp.get_defining_op of %[[OPERAND_0]] + // CHECK-DAG: %[[RESULT_0:.*]] = pdl_interp.get_result 0 of %[[DEF_OP_0]] + // CHECK-DAG: pdl_interp.are_equal %[[RESULT_0]], %[[OPERAND_0]] + + // Get the second result and check it matches the second operand. + // CHECK-DAG: %[[OPERAND_1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]] + // CHECK-DAG: %[[DEF_OP_1:.*]] = pdl_interp.get_defining_op of %[[OPERAND_1]] + // CHECK-DAG: %[[RESULT_1:.*]] = pdl_interp.get_result 1 of %[[DEF_OP_1]] + // CHECK-DAG: pdl_interp.are_equal %[[RESULT_1]], %[[OPERAND_1]] + + // Check that the parent operation of both results is the same. + // CHECK-DAG: pdl_interp.are_equal %[[DEF_OP_0]], %[[DEF_OP_1]] + + pdl.pattern : benefit(1) { + %type1 = pdl.type : i32 + %type2 = pdl.type + %inputOp = pdl.operation -> %type1, %type2 + %result1 = pdl.result 0 of %inputOp + %result2 = pdl.result 1 of %inputOp + + %root = pdl.operation(%result1, %result2) pdl.rewrite %root with "rewriter" } } @@ -134,12 +168,12 @@ // CHECK: pdl_interp.switch_type %[[RESULT_TYPE]] to [i32, i64] pdl.pattern : benefit(1) { %type = pdl.type : i32 - %root, %result = pdl.operation -> %type + %root = pdl.operation -> %type pdl.rewrite %root with "rewriter" } pdl.pattern : benefit(1) { %type = pdl.type : i64 - %root, %result = pdl.operation -> %type + %root = pdl.operation -> %type pdl.rewrite %root with "rewriter" } } @@ -161,13 +195,13 @@ pdl.pattern : benefit(1) { %resultType = pdl.type pdl.apply_constraint "typeConstraint"[](%resultType : !pdl.type) - %root, %result = pdl.operation -> %resultType + %root = pdl.operation -> %resultType pdl.rewrite %root with "rewriter" } pdl.pattern : benefit(1) { %resultType = pdl.type - %apply, %applyRes = pdl.operation -> %resultType + %apply = pdl.operation -> %resultType pdl.rewrite %apply with "rewriter" } } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -63,7 +63,8 @@ %root = pdl.operation "foo.op"(%operand) pdl.rewrite %root { %type = pdl.type : i32 - %newOp, %result = pdl.operation "foo.op"(%operand) -> %type + %newOp = pdl.operation "foo.op"(%operand) -> %type + %result = pdl.result 0 of %newOp %newOp1 = pdl.operation "foo.op2"(%result) pdl.erase %root } @@ -84,7 +85,8 @@ %root = pdl.operation "foo.op"(%operand) pdl.rewrite %root { %type = pdl.type : i32 - %newOp, %result = pdl.operation "foo.op"(%operand) -> %type + %newOp = pdl.operation "foo.op"(%operand) -> %type + %result = pdl.result 0 of %newOp %newOp1 = pdl.operation "foo.op2"(%result) pdl.erase %root } @@ -101,10 +103,10 @@ pdl.pattern : benefit(1) { %rootType = pdl.type %rootType1 = pdl.type - %root, %results:2 = pdl.operation "foo.op" -> %rootType, %rootType1 + %root = pdl.operation "foo.op" -> %rootType, %rootType1 pdl.rewrite %root { %newType1 = pdl.type - %newOp, %newResults:2 = pdl.operation "foo.op" -> %rootType, %newType1 + %newOp = pdl.operation "foo.op" -> %rootType, %newType1 pdl.replace %root with %newOp } } @@ -112,23 +114,6 @@ // ----- -// CHECK-LABEL: module @operation_result_types_infer_from_value_replacement -module @operation_result_types_infer_from_value_replacement { - // CHECK: module @rewriters - // CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type - // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]] - pdl.pattern : benefit(1) { - %rootType = pdl.type - %root, %result = pdl.operation "foo.op" -> %rootType - pdl.rewrite %root { - %newType = pdl.type - %newOp, %newResult = pdl.operation "foo.op" -> %newType - pdl.replace %root with (%newResult) - } - } -} -// ----- - // CHECK-LABEL: module @replace_with_op module @replace_with_op { // CHECK: module @rewriters @@ -138,9 +123,9 @@ // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]]) pdl.pattern : benefit(1) { %type = pdl.type : i32 - %root, %result = pdl.operation "foo.op" -> %type + %root = pdl.operation "foo.op" -> %type pdl.rewrite %root { - %newOp, %newResult = pdl.operation "foo.op" -> %type + %newOp = pdl.operation "foo.op" -> %type pdl.replace %root with %newOp } } @@ -157,9 +142,10 @@ // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]]) pdl.pattern : benefit(1) { %type = pdl.type : i32 - %root, %result = pdl.operation "foo.op" -> %type + %root = pdl.operation "foo.op" -> %type pdl.rewrite %root { - %newOp, %newResult = pdl.operation "foo.op" -> %type + %newOp = pdl.operation "foo.op" -> %type + %newResult = pdl.result 0 of %newOp pdl.replace %root with (%newResult) } } @@ -192,10 +178,10 @@ // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]] pdl.pattern : benefit(1) { %type = pdl.type - %root, %result = pdl.operation "foo.op" -> %type + %root = pdl.operation "foo.op" -> %type pdl.rewrite %root { %newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type - %newOp, %newResult = pdl.operation "foo.op" -> %newType + %newOp = pdl.operation "foo.op" -> %newType pdl.replace %root with %newOp } } diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -24,7 +24,7 @@ // expected-error@below {{expected only one of [`type`, `value`] to be set}} %attr = pdl.attribute : %type 10 - %op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type + %op = pdl.operation "foo.op" {"attr" = %attr} -> %type pdl.rewrite %op with "rewriter" } @@ -108,7 +108,7 @@ // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}} // expected-note@below {{result type #0 was not constrained}} - %newOp, %result = pdl.operation "foo.op" -> %type + %newOp = pdl.operation "foo.op" -> %type } } @@ -147,28 +147,12 @@ // ----- -//===----------------------------------------------------------------------===// -// pdl::ReplaceOp -//===----------------------------------------------------------------------===// - -pdl.pattern : benefit(1) { - %root = pdl.operation "foo.op" - pdl.rewrite %root { - %type = pdl.type : i32 - %newOp, %newResult = pdl.operation "foo.op" -> %type - - // expected-error@below {{to have the same number of results as the replacement operation}} - pdl.replace %root with %newOp - } -} - -// ----- - pdl.pattern : benefit(1) { %type = pdl.type : i32 - %root, %oldResult = pdl.operation "foo.op" -> %type + %root = pdl.operation "foo.op" -> %type pdl.rewrite %root { - %newOp, %newResult = pdl.operation "foo.op" -> %type + %newOp = pdl.operation "foo.op" -> %type + %newResult = pdl.result 0 of %newOp // expected-error@below {{expected no replacement values to be provided when the replacement operation is present}} "pdl.replace"(%root, %newOp, %newResult) { @@ -179,19 +163,6 @@ // ----- -pdl.pattern : benefit(1) { - %root = pdl.operation "foo.op" - pdl.rewrite %root { - %type = pdl.type : i32 - %newOp, %newResult = pdl.operation "foo.op" -> %type - - // expected-error@below {{to have the same number of results as the provided replacement values}} - pdl.replace %root with (%newResult) - } -} - -// ----- - //===----------------------------------------------------------------------===// // pdl::RewriteOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -8,7 +8,8 @@ // Operation with attributes and results. %attribute = pdl.attribute %type = pdl.type - %op0, %op0_result = pdl.operation {"attr" = %attribute} -> %type + %op0 = pdl.operation {"attr" = %attribute} -> %type + %op0_result = pdl.result 0 of %op0 // Operation with input. %input = pdl.operand @@ -46,38 +47,23 @@ pdl.pattern @infer_type_from_operation_replace : benefit(1) { %type1 = pdl.type : i32 %type2 = pdl.type - %root, %results:2 = pdl.operation -> %type1, %type2 + %root = pdl.operation -> %type1, %type2 pdl.rewrite %root { %type3 = pdl.type - %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3 + %newOp = pdl.operation "foo.op" -> %type1, %type3 pdl.replace %root with %newOp } } // ----- -// Check that the result type of an operation within a rewrite can be inferred -// from a pdl.replace. -pdl.pattern @infer_type_from_result_replace : benefit(1) { - %type1 = pdl.type : i32 - %type2 = pdl.type - %root, %results:2 = pdl.operation -> %type1, %type2 - pdl.rewrite %root { - %type3 = pdl.type - %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3 - pdl.replace %root with (%newResults#0, %newResults#1) - } -} - -// ----- - // Check that the result type of an operation within a rewrite can be inferred // from a pdl.replace. pdl.pattern @infer_type_from_type_used_in_match : benefit(1) { %type1 = pdl.type : i32 %type2 = pdl.type - %root, %results:2 = pdl.operation -> %type1, %type2 + %root = pdl.operation -> %type1, %type2 pdl.rewrite %root { - %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2 + %newOp = pdl.operation "foo.op" -> %type1, %type2 } }