diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -168,7 +168,7 @@ def PDLInterp_AreEqualOp : PDLInterp_PredicateOp<"are_equal", [NoSideEffect, SameTypeOperands]> { - let summary = "Check if two positional values are equivalent"; + let summary = "Check if two positional values or ranges are equivalent"; let description = [{ `pdl_interp.are_equal` operations compare two positional values for equality. On success, this operation branches to the true destination, @@ -241,19 +241,29 @@ let summary = "Check the number of operands of an `Operation`"; let description = [{ `pdl_interp.check_operand_count` operations compare the number of operands - of a given operation value with a constant. On success, this operation - branches to the true destination, otherwise the false destination is taken. + of a given operation value with a constant. The comparison is either exact + or at_least, with the latter used to compare against a minimum number of + expected operands. On success, this operation branches to the true + destination, otherwise the false destination is taken. Example: ```mlir + // Check for exact equality. pdl_interp.check_operand_count of %op is 2 -> ^matchDest, ^failureDest + + // Check for at least N operands. + pdl_interp.check_operand_count of %op is at_least 2 -> ^matchDest, ^failureDest ``` }]; let arguments = (ins PDL_Operation:$operation, - Confined:$count); - let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors"; + Confined:$count, + UnitAttr:$compareAtLeast); + let assemblyFormat = [{ + `of` $operation `is` (`at_least` $compareAtLeast^)? 
$count attr-dict + `->` successors + }]; } //===----------------------------------------------------------------------===// @@ -288,19 +298,29 @@ let summary = "Check the number of results of an `Operation`"; let description = [{ `pdl_interp.check_result_count` operations compare the number of results - of a given operation value with a constant. On success, this operation - branches to the true destination, otherwise the false destination is taken. + of a given operation value with a constant. The comparison is either exact + or at_least, with the latter used to compare against a minimum number of + expected results. On success, this operation branches to the true + destination, otherwise the false destination is taken. Example: ```mlir - pdl_interp.check_result_count of %op is 0 -> ^matchDest, ^failureDest + // Check for exact equality. + pdl_interp.check_result_count of %op is 2 -> ^matchDest, ^failureDest + + // Check for at least N results. + pdl_interp.check_result_count of %op is at_least 2 -> ^matchDest, ^failureDest ``` }]; let arguments = (ins PDL_Operation:$operation, - Confined:$count); - let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors"; + Confined:$count, + UnitAttr:$compareAtLeast); + let assemblyFormat = [{ + `of` $operation `is` (`at_least` $compareAtLeast^)? 
$count attr-dict + `->` successors + }]; } //===----------------------------------------------------------------------===// @@ -326,6 +346,30 @@ let assemblyFormat = "$value `is` $type attr-dict `->` successors"; } +//===----------------------------------------------------------------------===// +// pdl_interp::CheckTypesOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CheckTypesOp + : PDLInterp_PredicateOp<"check_types", [NoSideEffect]> { + let summary = "Compare a range of types to a range of known values"; + let description = [{ + `pdl_interp.check_types` operations compare a range of types with a + statically known range of types. On success, this operation branches + to the true destination, otherwise the false destination is taken. + + Example: + + ```mlir + pdl_interp.check_types %type are [i32, i64] -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins PDL_RangeOf:$value, + TypeArrayAttr:$types); + let assemblyFormat = "$value `are` $types attr-dict `->` successors"; +} + //===----------------------------------------------------------------------===// // pdl_interp::CreateAttributeOp //===----------------------------------------------------------------------===// @@ -363,21 +407,23 @@ let summary = "Create an instance of a specific `Operation`"; let description = [{ `pdl_interp.create_operation` operations create an `Operation` instance with - the specified attributes, operands, and result types. + the specified attributes, operands, and result types. See `pdl.operation` + for a more detailed description on the interpretation of the arguments to + this operation. Example: ```mlir // Create an instance of a `foo.op` operation. 
- %op = pdl_interp.create_operation "foo.op"(%arg0) {"attrA" = %attr0} -> %type, %type + %op = pdl_interp.create_operation "foo.op"(%arg0 : !pdl.value) {"attrA" = %attr0} -> (%type : !pdl.type) ``` }]; let arguments = (ins StrAttr:$name, - Variadic:$operands, + Variadic>:$operands, Variadic:$attributes, StrArrayAttr:$attributeNames, - Variadic:$types); + Variadic>:$types); let results = (outs PDL_Operation:$operation); let builders = [ @@ -386,9 +432,13 @@ "ArrayAttr":$attributeNames), [{ build($_builder, $_state, $_builder.getType(), name, operands, attributes, attributeNames, types); - }]>]; - let parser = [{ return ::parseCreateOperationOp(parser, result); }]; - let printer = [{ ::print(p, *this); }]; + }]> + ]; + let assemblyFormat = [{ + $name (`(` $operands^ `:` type($operands) `)`)? + custom($attributes, $attributeNames) + (`->` `(` $types^ `:` type($types) `)`)? attr-dict + }]; } //===----------------------------------------------------------------------===// @@ -419,6 +469,28 @@ ]; } +//===----------------------------------------------------------------------===// +// pdl_interp::CreateTypesOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CreateTypesOp : PDLInterp_Op<"create_types", [NoSideEffect]> { + let summary = "Create an interpreter handle to a range of constant `Type`s"; + let description = [{ + `pdl_interp.create_types` operations generate a handle within the + interpreter for a specific range of constant type values. 
+ + Example: + + ```mlir + pdl_interp.create_types [i64, i64] + ``` + }]; + + let arguments = (ins TypeArrayAttr:$value); + let results = (outs PDL_RangeOf:$result); + let assemblyFormat = "$value attr-dict"; +} + //===----------------------------------------------------------------------===// // pdl_interp::EraseOp //===----------------------------------------------------------------------===// @@ -523,19 +595,20 @@ let summary = "Get the defining operation of a `Value`"; let description = [{ `pdl_interp.get_defining_op` operations try to get the defining operation - of a specific value. If the value is not an operation result, null is - returned. + of a specific value or range of values. In the case of a range, the defining + op of the first value is returned. If the value is not an operation result + or range of operation results, null is returned. Example: ```mlir - %op = pdl_interp.get_defining_op of %value + %op = pdl_interp.get_defining_op of %value : !pdl.value ``` }]; - let arguments = (ins PDL_Value:$value); + let arguments = (ins PDL_InstOrRangeOf:$value); let results = (outs PDL_Operation:$operation); - let assemblyFormat = "`of` $value attr-dict"; + let assemblyFormat = "`of` $value `:` type($value) attr-dict"; } //===----------------------------------------------------------------------===// @@ -562,6 +635,49 @@ let assemblyFormat = "$index `of` $operation attr-dict"; } +//===----------------------------------------------------------------------===// +// pdl_interp::GetOperandsOp +//===----------------------------------------------------------------------===// + +def PDLInterp_GetOperandsOp : PDLInterp_Op<"get_operands", [NoSideEffect]> { + let summary = "Get a specified operand group from an `Operation`"; + let description = [{ + `pdl_interp.get_operands` operations try to get a specific operand + group from an operation. If the expected result is a single Value, null is + returned if the operand group is not of size 1.
If a range is expected, + null is returned if the operand group is invalid. If no index is provided, + the returned operand group corresponds to all operands of the operation. + + Example: + + ```mlir + // Get the first group of operands from an operation, and expect a single + // element. + %operand = pdl_interp.get_operands 0 of %op : !pdl.value + + // Get the first group of operands from an operation. + %operands = pdl_interp.get_operands 0 of %op : !pdl.range + + // Get all of the operands from an operation. + %operands = pdl_interp.get_operands of %op : !pdl.range + ``` + }]; + + let arguments = (ins + PDL_Operation:$operation, + OptionalAttr>:$index + ); + let results = (outs PDL_InstOrRangeOf:$value); + let assemblyFormat = "($index^)? `of` $operation `:` type($value) attr-dict"; + let builders = [ + OpBuilderDAG<(ins "Type":$resultType, "Value":$operation, + "Optional":$index), [{ + build($_builder, $_state, resultType, operation, + index ? $_builder.getI32IntegerAttr(*index) : IntegerAttr()); + }]>, + ]; +} + //===----------------------------------------------------------------------===// // pdl_interp::GetResultOp //===----------------------------------------------------------------------===// @@ -586,59 +702,117 @@ let assemblyFormat = "$index `of` $operation attr-dict"; } +//===----------------------------------------------------------------------===// +// pdl_interp::GetResultsOp +//===----------------------------------------------------------------------===// + +def PDLInterp_GetResultsOp : PDLInterp_Op<"get_results", [NoSideEffect]> { + let summary = "Get a specified result group from an `Operation`"; + let description = [{ + `pdl_interp.get_results` operations try to get a specific result group + from an operation. If the expected result is a single Value, null is + returned if the result group is not of size 1. If a range is expected, + null is returned if the result group is invalid.
If no index is provided, + the returned result group corresponds to all results of the operation. + + Example: + + ```mlir + // Get the first group of results from an operation, and expect a single + // element. + %result = pdl_interp.get_results 0 of %op : !pdl.value + + // Get the first group of results from an operation. + %results = pdl_interp.get_results 0 of %op : !pdl.range + + // Get all of the results from an operation. + %results = pdl_interp.get_results of %op : !pdl.range + ``` + }]; + + let arguments = (ins + PDL_Operation:$operation, + OptionalAttr>:$index + ); + let results = (outs PDL_InstOrRangeOf:$value); + let assemblyFormat = "($index^)? `of` $operation `:` type($value) attr-dict"; + let builders = [ + OpBuilderDAG<(ins "Type":$resultType, "Value":$operation, + "Optional":$index), [{ + build($_builder, $_state, resultType, operation, + index ? $_builder.getI32IntegerAttr(*index) : IntegerAttr()); + }]>, + OpBuilderDAG<(ins "Value":$operation), [{ + build($_builder, $_state, + pdl::RangeType::get($_builder.getType()), operation, + IntegerAttr()); + }]>, + ]; +} + //===----------------------------------------------------------------------===// // pdl_interp::GetValueTypeOp //===----------------------------------------------------------------------===// -// Get a type from the root operation, held in the rewriter context. -def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect]> { +def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect, + TypesMatchWith<"`value` type matches arity of `result`", + "result", "value", "getGetValueTypeOpValueType($_self)">]> { let summary = "Get the result type of a specified `Value`"; let description = [{ `pdl_interp.get_value_type` operations get the resulting type of a specific - value. + value or range thereof. Example: ```mlir - %type = pdl_interp.get_value_type of %value + // Get the type of a single value.
+ %type = pdl_interp.get_value_type of %value : !pdl.type + + // Get the types of a value range. + %type = pdl_interp.get_value_type of %values : !pdl.range ``` }]; - let arguments = (ins PDL_Value:$value); - let results = (outs PDL_Type:$result); - let assemblyFormat = "`of` $value attr-dict"; + let arguments = (ins PDL_InstOrRangeOf:$value); + let results = (outs PDL_InstOrRangeOf:$result); + let assemblyFormat = "`of` $value `:` type($result) attr-dict"; let builders = [ OpBuilderDAG<(ins "Value":$value), [{ - build($_builder, $_state, $_builder.getType(), value); + Type valType = value.getType(); + Type typeType = $_builder.getType(); + build($_builder, $_state, + valType.isa() ? pdl::RangeType::get(typeType) + : typeType, + value); }]> ]; } //===----------------------------------------------------------------------===// -// pdl_interp::InferredTypeOp +// pdl_interp::InferredTypesOp //===----------------------------------------------------------------------===// -def PDLInterp_InferredTypeOp : PDLInterp_Op<"inferred_type"> { - let summary = "Generate a handle to a Type that is \"inferred\""; +def PDLInterp_InferredTypesOp : PDLInterp_Op<"inferred_types"> { + let summary = "Generate a handle to a range of Types that are \"inferred\""; let description = [{ - `pdl_interp.inferred_type` operations generate a handle to a type that - should be inferred. This signals to other operations, such as - `pdl_interp.create_operation`, that this type should be inferred. + `pdl_interp.inferred_types` operations generate handles to ranges of types + that should be inferred. This signals to other operations, such as + `pdl_interp.create_operation`, that these types should be inferred. 
Example: ```mlir - pdl_interp.inferred_type + %types = pdl_interp.inferred_types ``` }]; - let results = (outs PDL_Type:$type); + let results = (outs PDL_RangeOf:$type); let assemblyFormat = "attr-dict"; - let builders = [ OpBuilderDAG<(ins), [{ - build($_builder, $_state, $_builder.getType()); - }]>, + build($_builder, $_state, + pdl::RangeType::get($_builder.getType())); + }]> ]; } @@ -650,7 +824,8 @@ : PDLInterp_PredicateOp<"is_not_null", [NoSideEffect]> { let summary = "Check if a positional value is non-null"; let description = [{ - `pdl_interp.is_not_null` operations check that a positional value exists. On + `pdl_interp.is_not_null` operations check that a positional value or range + exists. For ranges, an empty range is not considered null. On success, this operation branches to the true destination. Otherwise, the false destination is taken. @@ -718,12 +893,15 @@ ```mlir // Replace root node with 2 values: - pdl_interp.replace %root with (%val0, %val1) + pdl_interp.replace %root with (%val0, %val1 : !pdl.value, !pdl.value) ``` }]; let arguments = (ins PDL_Operation:$operation, - Variadic:$replValues); - let assemblyFormat = "$operation `with` `(` $replValues `)` attr-dict"; + Variadic>:$replValues); + let assemblyFormat = [{ + $operation `with` ` ` `(` ($replValues^ `:` type($replValues))?
`)` + attr-dict + }]; } //===----------------------------------------------------------------------===// @@ -886,9 +1064,9 @@ }]; let builders = [ - OpBuilderDAG<(ins "Value":$edge, "TypeRange":$types, "Block *":$defaultDest, - "BlockRange":$dests), [{ - build($_builder, $_state, edge, $_builder.getTypeArrayAttr(types), + OpBuilderDAG<(ins "Value":$edge, "ArrayRef":$types, + "Block *":$defaultDest, "BlockRange":$dests), [{ + build($_builder, $_state, edge, $_builder.getArrayAttr(types), defaultDest, dests); }]>, ]; @@ -898,4 +1076,45 @@ }]; } +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchTypesOp +//===----------------------------------------------------------------------===// + +def PDLInterp_SwitchTypesOp : PDLInterp_SwitchOp<"switch_types", + [NoSideEffect]> { + let summary = "Switch on a range of `Type` values"; + let description = [{ + `pdl_interp.switch_types` operations compare a range of types with a set of + statically known ranges. If the value matches one of the provided case + values the destination for that case value is taken, otherwise the default + destination is taken. 
+ + Example: + + ```mlir + pdl_interp.switch_types %type to [[i32], [i64, i64]](^i32Dest, ^i64Dest) -> ^defaultDest + ``` + }]; + + let arguments = (ins + PDL_RangeOf:$value, + TypedArrayAttrBase:$caseValues + ); + let assemblyFormat = [{ + $value `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest + }]; + + let builders = [ + OpBuilderDAG<(ins "Value":$edge, "ArrayRef":$types, + "Block *":$defaultDest, "BlockRange":$dests), [{ + build($_builder, $_state, edge, $_builder.getArrayAttr(types), + defaultDest, dests); + }]>, + ]; + + let extraClassDeclaration = [{ + auto getCaseTypes() { return caseValues().getAsRange(); } + }]; +} + #endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1433,7 +1433,7 @@ CPred<"$_self.isa<::mlir::ArrayAttr>()">, // Guarantee all elements satisfy the constraints from `element` Concat<"::llvm::all_of($_self.cast<::mlir::ArrayAttr>(), " - "[](::mlir::Attribute attr) { return ", + "[&](::mlir::Attribute attr) { return ", SubstLeaves<"$_self", "attr", element.predicate>, "; })">]>, summary> { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -56,9 +56,8 @@ /// Create an interpreter switch predicate operation, with a provided default /// and several case destinations. - void generateSwitch(Block *currentBlock, Qualifier *question, Value val, - Block *defaultDest, - ArrayRef> dests); + void generateSwitch(SwitchNode *switchNode, Block *currentBlock, + Qualifier *question, Value val, Block *defaultDest); /// Create the interpreter operations to record a successful pattern match.
void generateRecordMatch(Block *currentBlock, Block *nextBlock, @@ -88,9 +87,15 @@ void generateRewriter(pdl::ResultOp resultOp, DenseMap &rewriteValues, function_ref mapRewriteValue); + void generateRewriter(pdl::ResultsOp resultOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); void generateRewriter(pdl::TypeOp typeOp, DenseMap &rewriteValues, function_ref mapRewriteValue); + void generateRewriter(pdl::TypesOp typeOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); /// Generate the values used for resolving the result types of an operation /// created within a dag rewriter region. @@ -200,12 +205,7 @@ // Generate code for a switch node. } else if (auto *switchNode = dyn_cast(&node)) { - // Collect the next blocks for all of the children and generate a switch. - llvm::MapVector children; - for (auto &it : switchNode->getChildren()) - children.insert({it.first, generateMatcher(*it.second)}); - generateSwitch(block, node.getQuestion(), val, nextBlock, - children.takeVector()); + generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock); // Generate code for a success node. } else if (auto *successNode = dyn_cast(&node)) { @@ -242,6 +242,14 @@ operandPos->getOperandNumber()); break; } + case Predicates::OperandGroupPos: { + auto *operandPos = cast(pos); + Type valueTy = builder.getType(); + value = builder.create( + loc, operandPos->isVariadic() ? 
pdl::RangeType::get(valueTy) : valueTy, + parentVal, operandPos->getOperandGroupNumber()); + break; + } case Predicates::AttributePos: { auto *attrPos = cast(pos); value = builder.create( @@ -250,10 +258,10 @@ break; } case Predicates::TypePos: { - if (parentVal.getType().isa()) - value = builder.create(loc, parentVal); - else + if (parentVal.getType().isa()) value = builder.create(loc, parentVal); + else + value = builder.create(loc, parentVal); break; } case Predicates::ResultPos: { @@ -263,6 +271,14 @@ resPos->getResultNumber()); break; } + case Predicates::ResultGroupPos: { + auto *resPos = cast(pos); + Type valueTy = builder.getType(); + value = builder.create( + loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, + parentVal, resPos->getResultGroupNumber()); + break; + } default: llvm_unreachable("Generating unknown Position getter"); break; @@ -277,7 +293,8 @@ Block *falseDest) { builder.setInsertionPointToEnd(currentBlock); Location loc = val.getLoc(); - switch (question->getKind()) { + Predicates::Kind kind = question->getKind(); + switch (kind) { case Predicates::IsNotNullQuestion: builder.create(loc, val, trueDest, falseDest); break; @@ -289,8 +306,12 @@ } case Predicates::TypeQuestion: { auto *ans = cast(answer); - builder.create( - loc, val, TypeAttr::get(ans->getValue()), trueDest, falseDest); + if (val.getType().isa()) + builder.create( + loc, val, ans->getValue().cast(), trueDest, falseDest); + else + builder.create( + loc, val, ans->getValue().cast(), trueDest, falseDest); break; } case Predicates::AttributeQuestion: { @@ -299,18 +320,20 @@ trueDest, falseDest); break; } - case Predicates::OperandCountQuestion: { - auto *unsignedAnswer = cast(answer); + case Predicates::OperandCountAtLeastQuestion: + case Predicates::OperandCountQuestion: builder.create( - loc, val, unsignedAnswer->getValue(), trueDest, falseDest); + loc, val, cast(answer)->getValue(), + /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, + trueDest, 
falseDest); break; - } - case Predicates::ResultCountQuestion: { - auto *unsignedAnswer = cast(answer); + case Predicates::ResultCountAtLeastQuestion: + case Predicates::ResultCountQuestion: builder.create( - loc, val, unsignedAnswer->getValue(), trueDest, falseDest); + loc, val, cast(answer)->getValue(), + /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, + trueDest, falseDest); break; - } case Predicates::EqualToQuestion: { auto *equalToQuestion = cast(question); builder.create( @@ -336,7 +359,7 @@ template static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, - ArrayRef> dests) { + llvm::MapVector &dests) { std::vector values; std::vector blocks; values.reserve(dests.size()); @@ -348,27 +371,83 @@ builder.create(val.getLoc(), val, values, defaultDest, blocks); } -void PatternLowering::generateSwitch( - Block *currentBlock, Qualifier *question, Value val, Block *defaultDest, - ArrayRef> dests) { +void PatternLowering::generateSwitch(SwitchNode *switchNode, + Block *currentBlock, Qualifier *question, + Value val, Block *defaultDest) { + // If the switch question is not an exact answer, i.e. for the `at_least` + // cases, we generate a special block sequence. + Predicates::Kind kind = question->getKind(); + if (kind == Predicates::OperandCountAtLeastQuestion || + kind == Predicates::ResultCountAtLeastQuestion) { + // Order the children such that the cases are in reverse numerical order. + SmallVector sortedChildren( + llvm::seq(0, switchNode->getChildren().size())); + llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { + return cast(switchNode->getChild(lhs).first)->getValue() > + cast(switchNode->getChild(rhs).first)->getValue(); + }); + + // Build the destination for each child using the next highest child as + // a failure destination. This essentially creates the following control + // flow: + // + // if (operand_count < 1) + // goto failure + // if (child1.match()) + // ...
+ // + // if (operand_count < 2) + // goto failure + // if (child2.match()) + // ... + // + // failure: + // ... + // + failureBlockStack.push_back(defaultDest); + for (unsigned idx : sortedChildren) { + auto &child = switchNode->getChild(idx); + Block *childBlock = generateMatcher(*child.second); + Block *predicateBlock = builder.createBlock(childBlock); + generatePredicate(predicateBlock, question, child.first, val, childBlock, + defaultDest); + failureBlockStack.back() = predicateBlock; + } + Block *firstPredicateBlock = failureBlockStack.pop_back_val(); + currentBlock->getOperations().splice(currentBlock->end(), + firstPredicateBlock->getOperations()); + firstPredicateBlock->erase(); + return; + } + + // Otherwise, generate each of the children and generate an interpreter + // switch. + llvm::MapVector children; + for (auto &it : switchNode->getChildren()) + children.insert({it.first, generateMatcher(*it.second)}); builder.setInsertionPointToEnd(currentBlock); + switch (question->getKind()) { case Predicates::OperandCountQuestion: return createSwitchOp(val, defaultDest, builder, dests); + int32_t>(val, defaultDest, builder, children); case Predicates::ResultCountQuestion: return createSwitchOp(val, defaultDest, builder, dests); + int32_t>(val, defaultDest, builder, children); case Predicates::OperationNameQuestion: return createSwitchOp(val, defaultDest, builder, - dests); + children); case Predicates::TypeQuestion: + if (val.getType().isa()) { + return createSwitchOp( + val, defaultDest, builder, children); + } return createSwitchOp( - val, defaultDest, builder, dests); + val, defaultDest, builder, children); case Predicates::AttributeQuestion: return createSwitchOp( - val, defaultDest, builder, dests); + val, defaultDest, builder, children); default: llvm_unreachable("Generating unknown switch predicate."); } @@ -436,6 +515,11 @@ return newValue = builder.create( typeOp.getLoc(), type); } + } else if (pdl::TypesOp typeOp = dyn_cast(oldOp)) { + if (ArrayAttr 
type = typeOp.typesAttr()) { + return newValue = builder.create( + typeOp.getLoc(), typeOp.getType(), type); + } } // Otherwise, add this as an input to the rewriter. @@ -460,10 +544,10 @@ for (Operation &rewriteOp : *rewriter.getBody()) { llvm::TypeSwitch(&rewriteOp) .Case( - [&](auto op) { - this->generateRewriter(op, rewriteValues, mapRewriteValue); - }); + pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, + pdl::TypeOp, pdl::TypesOp>([&](auto op) { + this->generateRewriter(op, rewriteValues, mapRewriteValue); + }); } } @@ -529,14 +613,39 @@ rewriteValues[operationOp.op()] = createdOp; // Generate accesses for any results that have their types constrained. - for (auto it : llvm::enumerate(operationOp.types())) { + // Handle the case where there is a single range representing all of the + // result types. + OperandRange resultTys = operationOp.types(); + if (resultTys.size() == 1 && resultTys[0].getType().isa()) { + Value &type = rewriteValues[resultTys[0]]; + if (!type) { + auto results = builder.create(loc, createdOp); + type = builder.create(loc, results); + } + return; + } + + // Otherwise, populate the individual results. + bool seenVariableLength = false; + Type valueTy = builder.getType(); + Type valueRangeTy = pdl::RangeType::get(valueTy); + for (auto it : llvm::enumerate(resultTys)) { Value &type = rewriteValues[it.value()]; if (type) continue; - - Value getResultVal = builder.create( - loc, builder.getType(), createdOp, it.index()); - type = builder.create(loc, getResultVal); + bool isVariadic = it.value().getType().isa(); + seenVariableLength |= isVariadic; + + // After a variable length result has been seen, we need to use result + // groups because the exact index of the result is not statically known. + Value resultVal; + if (seenVariableLength) + resultVal = builder.create( + loc, isVariadic ? 
valueRangeTy : valueTy, createdOp, it.index()); + else + resultVal = builder.create( + loc, valueTy, createdOp, it.index()); + type = builder.create(loc, resultVal); } } @@ -549,11 +658,12 @@ // for using an operation for simplicitly, but the interpreter isn't as // user facing. if (Value replOp = replaceOp.replOperation()) { - pdl::OperationOp op = cast(replOp.getDefiningOp()); - for (unsigned i = 0, e = op.types().size(); i < e; ++i) - replOperands.push_back(builder.create( - replOp.getLoc(), builder.getType(), - mapRewriteValue(replOp), i)); + // Don't use replace if we know the replaced operation has no results. + auto opOp = replaceOp.operation().getDefiningOp(); + if (!opOp || !opOp.types().empty()) { + replOperands.push_back(builder.create( + replOp.getLoc(), mapRewriteValue(replOp))); + } } else { for (Value operand : replaceOp.replValues()) replOperands.push_back(mapRewriteValue(operand)); @@ -578,15 +688,33 @@ mapRewriteValue(resultOp.parent()), resultOp.index()); } +void PatternLowering::generateRewriter( + pdl::ResultsOp resultOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + rewriteValues[resultOp] = builder.create( + resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), + resultOp.index()); +} + void PatternLowering::generateRewriter( pdl::TypeOp typeOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { // If the type isn't constant, the users (e.g. OperationOp) will resolve this // type. if (TypeAttr typeAttr = typeOp.typeAttr()) { - Value newType = + rewriteValues[typeOp] = builder.create(typeOp.getLoc(), typeAttr); - rewriteValues[typeOp] = newType; + } +} + +void PatternLowering::generateRewriter( + pdl::TypesOp typeOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + // If the type isn't constant, the users (e.g. OperationOp) will resolve this + // type. 
+ if (ArrayAttr typeAttr = typeOp.typesAttr()) { + rewriteValues[typeOp] = builder.create( + typeOp.getLoc(), typeOp.getType(), typeAttr); } } @@ -594,28 +722,38 @@ pdl::OperationOp op, SmallVectorImpl &types, DenseMap &rewriteValues, function_ref mapRewriteValue) { - // Functor that returns if the given use can be used to infer a type. + // Look for an operation that was replaced by `op`. The result types will be + // inferred from the results that were replaced. Block *rewriterBlock = op->getBlock(); - auto getReplacedOperationFrom = [&](OpOperand &use) -> Operation * { + Value replacedOp; + for (OpOperand &use : op.op().getUses()) { // Check that the use corresponds to a ReplaceOp and that it is the // replacement value, not the operation being replaced. pdl::ReplaceOp replOpUser = dyn_cast(use.getOwner()); if (!replOpUser || use.getOperandNumber() == 0) - return nullptr; + continue; // Make sure the replaced operation was defined before this one. - Operation *replacedOp = replOpUser.operation().getDefiningOp(); - if (replacedOp->getBlock() != rewriterBlock || - replacedOp->isBeforeInBlock(op)) - return replacedOp; - return nullptr; - }; + Value replOpVal = replOpUser.operation(); + Operation *replacedOp = replOpVal.getDefiningOp(); + if (replacedOp->getBlock() == rewriterBlock && + !replacedOp->isBeforeInBlock(op)) + continue; + + Value replacedOpResults = builder.create( + replacedOp->getLoc(), mapRewriteValue(replOpVal)); + types.push_back(builder.create( + replacedOp->getLoc(), replacedOpResults)); + return; + } + + // Check if the operation has type inference support. + if (op.hasTypeInference()) { + types.push_back(builder.create(op.getLoc())); + return; + } - // If non-None/non-Null, this is an operation that is replaced by `op`. - // If Null, there is no full replacement operation for `op`. - // If None, a replacement operation hasn't been searched for. 
- Optional fullReplacedOperation; - bool hasTypeInference = op.hasTypeInference(); - auto resultTypeValues = op.types(); + // Otherwise, handle inference for each of the result types individually. + OperandRange resultTypeValues = op.types(); types.reserve(resultTypeValues.size()); for (auto it : llvm::enumerate(resultTypeValues)) { Value resultType = it.value(); @@ -632,30 +770,11 @@ continue; } - // Check if the operation has type inference support. - if (hasTypeInference) { - types.push_back(builder.create(op.getLoc())); - continue; - } - - // Look for an operation that was replaced by `op`. The result type will be - // inferred from the result that was replaced. There is guaranteed to be a - // replacement for either the op, or this specific result. Note that this is - // guaranteed by the verifier of `pdl::OperationOp`. - Operation *replacedOp = nullptr; - if (!fullReplacedOperation.hasValue()) { - for (OpOperand &use : op.op().getUses()) - if ((replacedOp = getReplacedOperationFrom(use))) - break; - fullReplacedOperation = replacedOp; - assert(fullReplacedOperation && - "expected replaced op to infer a result type from"); - } else { - replacedOp = fullReplacedOperation.getValue(); - } - - auto replOpOp = cast(replacedOp); - types.push_back(mapRewriteValue(replOpOp.types()[it.index()])); + // The verifier asserts that the result types of each pdl.operation can be + // inferred. If we reach here, there is a bug either in the logic above or + // in the verifier for pdl.operation. + op->emitOpError() << "unable to infer result type for operation"; + llvm_unreachable("unable to infer result type for operation"); } } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -45,8 +45,10 @@ /// Positions, ordered by decreasing priority. 
OperationPos, OperandPos, + OperandGroupPos, AttributePos, ResultPos, + ResultGroupPos, TypePos, // Questions, ordered by dependency and decreasing priority. @@ -54,7 +56,9 @@ OperationNameQuestion, TypeQuestion, AttributeQuestion, + OperandCountAtLeastQuestion, OperandCountQuestion, + ResultCountAtLeastQuestion, ResultCountQuestion, EqualToQuestion, ConstraintQuestion, @@ -129,21 +133,15 @@ /// predicates, and assists generating bytecode and memory management. /// /// Operation positions form the base of other positions, which are formed -/// relative to a parent operation, e.g. OperandPosition<[0] -> 1>. Operations -/// are indexed by child index: [0, 1, 2] refers to the 3rd child of the 2nd -/// child of the root operation. -/// -/// Positions are linked to their parent position, which describes how to obtain -/// a positional value. As a concrete example, getting OperationPosition<[0, 1]> -/// would be `root->getOperand(1)->getDefiningOp()`, so its parent is -/// OperandPosition<[0] -> 1>, whose parent is OperationPosition<[0]>. +/// relative to a parent operation. Operations are anchored at Operand nodes, +/// except for the root operation which is parentless. class Position : public StorageUniquer::BaseStorage { public: explicit Position(Predicates::Kind kind) : kind(kind) {} virtual ~Position(); - /// Returns the base node position. This is an array of indices. - virtual ArrayRef getIndex() const = 0; + /// Returns the depth of the first ancestor operation position. + unsigned getOperationDepth() const; /// Returns the parent position. The root operation position has no parent. Position *getParent() const { return parent; } @@ -170,9 +168,6 @@ Predicates::AttributePos> { explicit AttributePosition(const KeyTy &key); - /// Returns the index of this position. - ArrayRef getIndex() const final { return parent->getIndex(); } - /// Returns the attribute name of this position. 
Identifier getName() const { return key.second; } }; @@ -187,42 +182,61 @@ Predicates::OperandPos> { explicit OperandPosition(const KeyTy &key); - /// Returns the index of this position. - ArrayRef getIndex() const final { return parent->getIndex(); } - /// Returns the operand number of this position. unsigned getOperandNumber() const { return key.second; } }; +//===----------------------------------------------------------------------===// +// OperandGroupPosition + +/// A position describing an operand group of an operation. +struct OperandGroupPosition + : public PredicateBase< + OperandGroupPosition, Position, + std::tuple, bool>, + Predicates::OperandGroupPos> { + explicit OperandGroupPosition(const KeyTy &key); + + /// Returns a hash suitable for the given keytype. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Returns the group number of this position. If None, this group refers to + /// all operands. + Optional getOperandGroupNumber() const { return std::get<1>(key); } + + /// Returns if the operand group has unknown size. If false, the operand group + /// has at max one element. + bool isVariadic() const { return std::get<2>(key); } +}; + //===----------------------------------------------------------------------===// // OperationPosition /// An operation position describes an operation node in the IR. Other position /// kinds are formed with respect to an operation position. -struct OperationPosition - : public PredicateBase, - Predicates::OperationPos> { - using Base::Base; +struct OperationPosition : public PredicateBase, + Predicates::OperationPos> { + explicit OperationPosition(const KeyTy &key) : Base(key) { + parent = key.first; + } - /// Gets the root position, which is always [0]. + /// Gets the root position. static OperationPosition *getRoot(StorageUniquer &uniquer) { - return get(uniquer, ArrayRef(0)); + return Base::get(uniquer, nullptr, 0); } - /// Gets a node position for the given index. 
- static OperationPosition *get(StorageUniquer &uniquer, - ArrayRef index); - - /// Constructs an instance with the given storage allocator. - static OperationPosition *construct(StorageUniquer::StorageAllocator &alloc, - ArrayRef key) { - return Base::construct(alloc, alloc.copyInto(key)); + /// Gets an operation position with the given parent. + static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { + return Base::get(uniquer, parent, parent->getOperationDepth() + 1); } - /// Returns the index of this position. - ArrayRef getIndex() const final { return key; } + /// Returns the depth of this position. + unsigned getDepth() const { return key.second; } /// Returns if this operation position corresponds to the root. - bool isRoot() const { return key.size() == 1 && key[0] == 0; } + bool isRoot() const { return getDepth() == 0; } }; //===----------------------------------------------------------------------===// @@ -235,13 +249,37 @@ Predicates::ResultPos> { explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } - /// Returns the index of this position. - ArrayRef getIndex() const final { return key.first->getIndex(); } - /// Returns the result number of this position. unsigned getResultNumber() const { return key.second; } }; +//===----------------------------------------------------------------------===// +// ResultGroupPosition + +/// A position describing a result group of an operation. +struct ResultGroupPosition + : public PredicateBase< + ResultGroupPosition, Position, + std::tuple, bool>, + Predicates::ResultGroupPos> { + explicit ResultGroupPosition(const KeyTy &key) : Base(key) { + parent = std::get<0>(key); + } + + /// Returns a hash suitable for the given keytype. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Returns the group number of this position. If None, this group refers to + /// all results. 
+ Optional getResultGroupNumber() const { return std::get<1>(key); } + + /// Returns if the result group has unknown size. If false, the result group + /// has at max one element. + bool isVariadic() const { return std::get<2>(key); } +}; + //===----------------------------------------------------------------------===// // TypePosition @@ -250,14 +288,11 @@ struct TypePosition : public PredicateBase { explicit TypePosition(const KeyTy &key) : Base(key) { - assert((isa(key) || isa(key) || - isa(key)) && + assert((isa(key)) && "expected parent to be an attribute, operand, or result"); parent = key; } - - /// Returns the index of this position. - ArrayRef getIndex() const final { return key->getIndex(); } }; //===----------------------------------------------------------------------===// @@ -311,8 +346,9 @@ using Base::Base; }; -/// An Answer representing a `Type` value. -struct TypeAnswer : public PredicateBase { using Base::Base; }; @@ -365,6 +401,9 @@ struct OperandCountQuestion : public PredicateBase {}; +struct OperandCountAtLeastQuestion + : public PredicateBase {}; /// Compare the name of an operation with a known value. struct OperationNameQuestion @@ -375,6 +414,9 @@ struct ResultCountQuestion : public PredicateBase {}; +struct ResultCountAtLeastQuestion + : public PredicateBase {}; /// Compare the type of an attribute or value with a known type. struct TypeQuestion : public PredicateBase(); registerParametricStorageType(); + registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); + registerParametricStorageType(); registerParametricStorageType(); // Register the types of Questions with the uniquer. 
@@ -409,8 +453,10 @@ registerSingletonStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); + registerSingletonStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); + registerSingletonStorageType(); registerSingletonStorageType(); } }; @@ -433,10 +479,10 @@ Position *getRoot() { return OperationPosition::getRoot(uniquer); } /// Returns the parent position defining the value held by the given operand. - OperationPosition *getParent(OperandPosition *p) { - std::vector index = p->getIndex(); - index.push_back(p->getOperandNumber()); - return OperationPosition::get(uniquer, index); + OperationPosition *getOperandDefiningOp(Position *p) { + assert((isa(p)) && + "expected operand position"); + return OperationPosition::get(uniquer, p); } /// Returns an attribute position for an attribute of the given operation. @@ -449,11 +495,29 @@ return OperandPosition::get(uniquer, p, operand); } + /// Returns a position for a group of operands of the given operation. + Position *getOperandGroup(OperationPosition *p, Optional group, + bool isVariadic) { + return OperandGroupPosition::get(uniquer, p, group, isVariadic); + } + Position *getAllOperands(OperationPosition *p) { + return getOperandGroup(p, /*group=*/llvm::None, /*isVariadic=*/true); + } + /// Returns a result position for a result of the given operation. Position *getResult(OperationPosition *p, unsigned result) { return ResultPosition::get(uniquer, p, result); } + /// Returns a position for a group of results of the given operation. + Position *getResultGroup(OperationPosition *p, Optional group, + bool isVariadic) { + return ResultGroupPosition::get(uniquer, p, group, isVariadic); + } + Position *getAllResults(OperationPosition *p) { + return getResultGroup(p, /*group=*/llvm::None, /*isVariadic=*/true); + } + /// Returns a type position for the given entity. 
Position *getType(Position *p) { return TypePosition::get(uniquer, p); } @@ -496,6 +560,10 @@ return {OperandCountQuestion::get(uniquer), UnsignedAnswer::get(uniquer, count)}; } + Predicate getOperandCountAtLeast(unsigned count) { + return {OperandCountAtLeastQuestion::get(uniquer), + UnsignedAnswer::get(uniquer, count)}; + } /// Create a predicate comparing the name of an operation to a known value. Predicate getOperationName(StringRef name) { @@ -509,10 +577,15 @@ return {ResultCountQuestion::get(uniquer), UnsignedAnswer::get(uniquer, count)}; } + Predicate getResultCountAtLeast(unsigned count) { + return {ResultCountAtLeastQuestion::get(uniquer), + UnsignedAnswer::get(uniquer, count)}; + } /// Create a predicate comparing the type of an attribute or value to a known - /// type. - Predicate getTypeConstraint(Type type) { + /// type. The value is stored as either a TypeAttr, or an ArrayAttr of + /// TypeAttr. + Predicate getTypeConstraint(Attribute type) { return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)}; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp @@ -17,6 +17,13 @@ Position::~Position() {} +/// Returns the depth of the first ancestor operation position. +unsigned Position::getOperationDepth() const { + if (auto *operationPos = dyn_cast(this)) + return operationPos->getDepth(); + return parent->getOperationDepth(); +} + //===----------------------------------------------------------------------===// // AttributePosition @@ -32,18 +39,8 @@ } //===----------------------------------------------------------------------===// -// OperationPosition - -OperationPosition *OperationPosition::get(StorageUniquer &uniquer, - ArrayRef index) { - assert(!index.empty() && "expected at least two indices"); - - // Set the parent position if this isn't the root. 
- Position *parent = nullptr; - if (index.size() > 1) { - auto *node = OperationPosition::get(uniquer, index.drop_back()); - parent = OperandPosition::get(uniquer, std::make_pair(node, index.back())); - } - return uniquer.get( - [parent](OperationPosition *node) { node->parent = parent; }, index); +// OperandGroupPosition + +OperandGroupPosition::OperandGroupPosition(const KeyTy &key) : Base(key) { + parent = std::get<0>(key); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h @@ -190,6 +190,12 @@ using ChildMapT = llvm::MapVector>; ChildMapT &getChildren() { return children; } + /// Returns the child at the given index. + std::pair> &getChild(unsigned i) { + assert(i < children.size() && "invalid child index"); + return *std::next(children.begin(), i); + } + private: /// Switch predicate "answers" select the child. Answers that are not found /// default to the failure node. diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -28,7 +28,13 @@ /// Compares the depths of two positions. static bool comparePosDepth(Position *lhs, Position *rhs) { - return lhs->getIndex().size() < rhs->getIndex().size(); + return lhs->getOperationDepth() < rhs->getOperationDepth(); +} + +/// Returns the number of non-range elements within `values`. 
+static unsigned getNumNonRangeValues(ValueRange values) { + return llvm::count_if(values.getTypes(), + [](Type type) { return !type.isa(); }); } static void getTreePredicates(std::vector &predList, @@ -46,28 +52,50 @@ predList.emplace_back(pos, builder.getAttributeConstraint(value)); } -static void getTreePredicates(std::vector &predList, - Value val, PredicateBuilder &builder, - DenseMap &inputs, - OperandPosition *pos) { - assert(val.getType().isa() && "expected value type"); - - // Prevent traversal into a null value. - predList.emplace_back(pos, builder.getIsNotNull()); +/// Collect all of the predicates for the given operand position. +static void getOperandTreePredicates(std::vector &predList, + Value val, PredicateBuilder &builder, + DenseMap &inputs, + Position *pos) { + Type valueType = val.getType(); + bool isVariadic = valueType.isa(); // If this is a typed operand, add a type constraint. - if (auto in = val.getDefiningOp()) { - if (Value type = in.type()) - getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); - - // Otherwise, recurse into a result node. - } else if (auto resultOp = val.getDefiningOp()) { - OperationPosition *parentPos = builder.getParent(pos); - Position *resultPos = builder.getResult(parentPos, resultOp.index()); - predList.emplace_back(parentPos, builder.getIsNotNull()); - predList.emplace_back(resultPos, builder.getEqualTo(pos)); - getTreePredicates(predList, resultOp.parent(), builder, inputs, parentPos); - } + TypeSwitch(val.getDefiningOp()) + .Case([&](auto op) { + // Prevent traversal into a null value if the operand has a proper + // index. + if (std::is_same::value || + cast(pos)->getOperandGroupNumber()) + predList.emplace_back(pos, builder.getIsNotNull()); + + if (Value type = op.type()) + getTreePredicates(predList, type, builder, inputs, + builder.getType(pos)); + }) + .Case([&](auto op) { + Optional index = op.index(); + + // Prevent traversal into a null value if the result has a proper index. 
+ if (index) + predList.emplace_back(pos, builder.getIsNotNull()); + + // Get the parent operation of this operand. + OperationPosition *parentPos = builder.getOperandDefiningOp(pos); + predList.emplace_back(parentPos, builder.getIsNotNull()); + + // Ensure that the operands match the corresponding results of the + // parent operation. + Position *resultPos = nullptr; + if (std::is_same::value) + resultPos = builder.getResult(parentPos, *index); + else + resultPos = builder.getResultGroup(parentPos, index, isVariadic); + predList.emplace_back(resultPos, builder.getEqualTo(pos)); + + // Collect the predicates of the parent operation. + getTreePredicates(predList, op.parent(), builder, inputs, parentPos); + }); } static void getTreePredicates(std::vector &predList, @@ -86,11 +114,27 @@ if (Optional opName = op.name()) predList.emplace_back(pos, builder.getOperationName(*opName)); - // Check that the operation has the proper number of operands and results. + // Check that the operation has the proper number of operands. If there are + // any variable length operands, we check a minimum instead of an exact count. OperandRange operands = op.operands(); + unsigned minOperands = getNumNonRangeValues(operands); + if (minOperands != operands.size()) { + if (minOperands) + predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands)); + } else { + predList.emplace_back(pos, builder.getOperandCount(minOperands)); + } + + // Check that the operation has the proper number of results. If there are + // any variable length results, we check a minimum instead of an exact count. 
OperandRange types = op.types(); - predList.emplace_back(pos, builder.getOperandCount(operands.size())); - predList.emplace_back(pos, builder.getResultCount(types.size())); + unsigned minResults = getNumNonRangeValues(types); + if (minResults != types.size()) { + if (minResults) + predList.emplace_back(pos, builder.getResultCountAtLeast(minResults)); + } else { + predList.emplace_back(pos, builder.getResultCount(types.size())); + } // Recurse into any attributes, operands, or results. for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { @@ -99,15 +143,47 @@ builder.getAttribute(opPos, std::get<0>(it).cast().getValue())); } - for (auto operandIt : llvm::enumerate(operands)) { - getTreePredicates(predList, operandIt.value(), builder, inputs, - builder.getOperand(opPos, operandIt.index())); + + // Process the operands and results of the operation. For all values up to + // the first variable length value, we use the concrete operand/result + // number. After that, we use the "group" given that we can't know the + // concrete indices until runtime. If there is only one variadic operand + // group, we treat it as all of the operands/results of the operation. + /// Operands. + if (operands.size() == 1 && operands[0].getType().isa()) { + getTreePredicates(predList, operands.front(), builder, inputs, + builder.getAllOperands(opPos)); + } else { + bool foundVariableLength = false; + for (auto operandIt : llvm::enumerate(operands)) { + bool isVariadic = operandIt.value().getType().isa(); + foundVariableLength |= isVariadic; + + Position *pos = + foundVariableLength + ? 
builder.getOperandGroup(opPos, operandIt.index(), isVariadic) + : builder.getOperand(opPos, operandIt.index()); + getTreePredicates(predList, operandIt.value(), builder, inputs, pos); + } } - for (auto &resultIt : llvm::enumerate(types)) { - auto *resultPos = builder.getResult(pos, resultIt.index()); - predList.emplace_back(resultPos, builder.getIsNotNull()); - getTreePredicates(predList, resultIt.value(), builder, inputs, - builder.getType(resultPos)); + /// Results. + if (types.size() == 1 && types[0].getType().isa()) { + getTreePredicates(predList, types.front(), builder, inputs, + builder.getType(builder.getAllResults(opPos))); + } else { + bool foundVariableLength = false; + for (auto &resultIt : llvm::enumerate(types)) { + bool isVariadic = resultIt.value().getType().isa(); + foundVariableLength |= isVariadic; + + auto *resultPos = + foundVariableLength + ? builder.getResultGroup(pos, resultIt.index(), isVariadic) + : builder.getResult(pos, resultIt.index()); + predList.emplace_back(resultPos, builder.getIsNotNull()); + getTreePredicates(predList, resultIt.value(), builder, inputs, + builder.getType(resultPos)); + } } } @@ -115,12 +191,14 @@ Value val, PredicateBuilder &builder, DenseMap &inputs, TypePosition *pos) { - assert(val.getType().isa() && "expected value type"); - pdl::TypeOp typeOp = cast(val.getDefiningOp()); - // Check for a constraint on a constant type. - if (Optional type = typeOp.type()) - predList.emplace_back(pos, builder.getTypeConstraint(*type)); + if (pdl::TypeOp typeOp = val.getDefiningOp()) { + if (Attribute type = typeOp.typeAttr()) + predList.emplace_back(pos, builder.getTypeConstraint(type)); + } else if (pdl::TypesOp typeOp = val.getDefiningOp()) { + if (Attribute typeAttr = typeOp.typesAttr()) + predList.emplace_back(pos, builder.getTypeConstraint(typeAttr)); + } } /// Collect the tree predicates anchored at the given value. 
@@ -133,8 +211,8 @@ if (!it.second) { // If this is an input value that has been visited in the tree, add a // constraint to ensure that both instances refer to the same value. - if (isa( - val.getDefiningOp())) { + if (isa(val.getDefiningOp())) { auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth); predList.emplace_back(minMaxPositions.second, @@ -144,9 +222,11 @@ } TypeSwitch(pos) - .Case([&](auto *derivedPos) { - getTreePredicates(predList, val, builder, inputs, derivedPos); + .Case([&](auto *pos) { + getTreePredicates(predList, val, builder, inputs, pos); + }) + .Case([&](auto *pos) { + getOperandTreePredicates(predList, val, builder, inputs, pos); }) .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); } @@ -180,11 +260,30 @@ Position *&resultPos = inputs[op]; if (resultPos) return; + + // Ensure that the result isn't null. auto *parentPos = cast(inputs.lookup(op.parent())); resultPos = builder.getResult(parentPos, op.index()); predList.emplace_back(resultPos, builder.getIsNotNull()); } +static void getResultPredicates(pdl::ResultsOp op, + std::vector &predList, + PredicateBuilder &builder, + DenseMap &inputs) { + Position *&resultPos = inputs[op]; + if (resultPos) + return; + + // Ensure that the result isn't null if the result has an index. + auto *parentPos = cast(inputs.lookup(op.parent())); + bool isVariadic = op.getType().isa(); + Optional index = op.index(); + resultPos = builder.getResultGroup(parentPos, index, isVariadic); + if (index) + predList.emplace_back(resultPos, builder.getIsNotNull()); +} + /// Collect all of the predicates that cannot be determined via walking the /// tree. 
static void getNonTreePredicates(pdl::PatternOp pattern, @@ -192,10 +291,13 @@ PredicateBuilder &builder, DenseMap &inputs) { for (Operation &op : pattern.body().getOps()) { - if (auto constraintOp = dyn_cast(&op)) - getConstraintPredicates(constraintOp, predList, builder, inputs); - else if (auto resultOp = dyn_cast(&op)) - getResultPredicates(resultOp, predList, builder, inputs); + TypeSwitch(&op) + .Case([&](auto constraintOp) { + getConstraintPredicates(constraintOp, predList, builder, inputs); + }) + .Case([&](auto resultOp) { + getResultPredicates(resultOp, predList, builder, inputs); + }); } } @@ -254,10 +356,10 @@ // * lower position dependency // * lower predicate dependency auto *rhsPos = rhs.position; - return std::make_tuple(primary, secondary, rhsPos->getIndex().size(), + return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(), rhsPos->getKind(), rhs.question->getKind()) > std::make_tuple(rhs.primary, rhs.secondary, - position->getIndex().size(), position->getKind(), + position->getOperationDepth(), position->getKind(), question->getKind()); } }; diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -29,28 +29,12 @@ // pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// -static ParseResult parseCreateOperationOp(OpAsmParser &p, - OperationState &state) { - if (p.parseOptionalAttrDict(state.attributes)) - return failure(); +static ParseResult parseCreateOperationOpAttributes( + OpAsmParser &p, SmallVectorImpl &attrOperands, + ArrayAttr &attrNamesAttr) { Builder &builder = p.getBuilder(); - - // Parse the operation name. - StringAttr opName; - if (p.parseAttribute(opName, "name", state.attributes)) - return failure(); - - // Parse the operands. 
- SmallVector operands; - if (p.parseLParen() || p.parseOperandList(operands) || p.parseRParen() || - p.resolveOperands(operands, builder.getType(), - state.operands)) - return failure(); - - // Parse the attributes. SmallVector attrNames; if (succeeded(p.parseOptionalLBrace())) { - SmallVector attrOps; do { StringAttr nameAttr; OpAsmParser::OperandType operand; @@ -58,60 +42,35 @@ p.parseOperand(operand)) return failure(); attrNames.push_back(nameAttr); - attrOps.push_back(operand); + attrOperands.push_back(operand); } while (succeeded(p.parseOptionalComma())); - - if (p.parseRBrace() || - p.resolveOperands(attrOps, builder.getType(), - state.operands)) - return failure(); - } - state.addAttribute("attributeNames", builder.getArrayAttr(attrNames)); - state.addTypes(builder.getType()); - - // Parse the result types. - SmallVector opResultTypes; - if (p.parseArrow()) - return failure(); - if (succeeded(p.parseOptionalLParen())) { - if (p.parseRParen()) + if (p.parseRBrace()) return failure(); - } else if (p.parseOperandList(opResultTypes) || - p.resolveOperands(opResultTypes, builder.getType(), - state.operands)) { - return failure(); } - - int32_t operandSegmentSizes[] = {static_cast(operands.size()), - static_cast(attrNames.size()), - static_cast(opResultTypes.size())}; - state.addAttribute("operand_segment_sizes", - builder.getI32VectorAttr(operandSegmentSizes)); + attrNamesAttr = builder.getArrayAttr(attrNames); return success(); } -static void print(OpAsmPrinter &p, CreateOperationOp op) { - p << "pdl_interp.create_operation "; - p.printOptionalAttrDict(op.getAttrs(), - {"attributeNames", "name", "operand_segment_sizes"}); - p << '"' << op.name() << "\"(" << op.operands() << ')'; +static void printCreateOperationOpAttributes(OpAsmPrinter &p, + CreateOperationOp op, + OperandRange attrArgs, + ArrayAttr attrNames) { + if (attrNames.empty()) + return; + p << " {"; + interleaveComma(llvm::seq(0, attrNames.size()), p, + [&](int i) { p << attrNames[i] << " = " << 
attrArgs[i]; }); + p << '}'; +} - // Emit the optional attributes. - ArrayAttr attrNames = op.attributeNames(); - if (!attrNames.empty()) { - Operation::operand_range attrArgs = op.attributes(); - p << " {"; - interleaveComma(llvm::seq(0, attrNames.size()), p, - [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); - p << '}'; - } +//===----------------------------------------------------------------------===// +// pdl_interp::GetValueTypeOp +//===----------------------------------------------------------------------===// - // Print the result type constraints of the operation. - auto types = op.types(); - if (types.empty()) - p << " -> ()"; - else - p << " -> " << op.types(); +/// Given the result type of a `GetValueTypeOp`, return the expected input type. +static Type getGetValueTypeOpValueType(Type type) { + Type valueTy = pdl::ValueType::get(type.getContext()); + return type.isa() ? pdl::RangeType::get(valueTy) : valueTy; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -208,7 +208,7 @@ void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); - void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); @@ -487,7 +487,7 @@ pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp, - pdl_interp::InferredTypeOp, pdl_interp::IsNotNullOp, + 
pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, @@ -615,9 +615,9 @@ ByteCodeWriter &writer) { writer.append(OpCode::GetValueType, op.result(), op.value()); } -void Generator::generate(pdl_interp::InferredTypeOp op, +void Generator::generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer) { - // InferType maps to a null type as a marker for inferring a result type. + // InferType maps to a null type as a marker for inferring result types. getMemIndex(op.type()) = getMemIndex(Type()); } void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { @@ -980,16 +980,12 @@ state.name.getAbstractOperation()->getInterface(); // TODO: Handle failure. - SmallVector inferredTypes; + state.types.clear(); if (failed(concept->inferReturnTypes( state.getContext(), state.location, state.operands, state.attributes.getDictionary(state.getContext()), state.regions, - inferredTypes))) + state.types))) return; - - for (unsigned i = 0, e = state.types.size(); i != e; ++i) - if (!state.types[i]) - state.types[i] = inferredTypes[i]; } Operation *resultOp = rewriter.createOperation(state); memory[memIndex] = resultOp; diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp --- a/mlir/lib/TableGen/Predicate.cpp +++ b/mlir/lib/TableGen/Predicate.cpp @@ -133,6 +133,23 @@ using Subst = std::pair; } // end anonymous namespace +/// Perform the given substitutions on 'str' in-place. +static void performSubstitutions(std::string &str, + ArrayRef substitutions) { + // Apply all parent substitutions from innermost to outermost. 
+ for (const auto &subst : llvm::reverse(substitutions)) { + auto pos = str.find(std::string(subst.first)); + while (pos != std::string::npos) { + str.replace(pos, subst.first.size(), std::string(subst.second)); + // Skip the newly inserted substring, which itself may consider the + // pattern to match. + pos += subst.second.size(); + // Find the next possible match position. + pos = str.find(std::string(subst.first), pos); + } + } +} + // Build the predicate tree starting from the top-level predicate, which may // have children, and perform leaf substitutions inplace. Note that after // substitution, nodes are still pointing to the original TableGen record. @@ -147,19 +164,7 @@ rootNode->predicate = &root; if (!root.isCombined()) { rootNode->expr = root.getCondition(); - // Apply all parent substitutions from innermost to outermost. - for (const auto &subst : llvm::reverse(substitutions)) { - auto pos = rootNode->expr.find(std::string(subst.first)); - while (pos != std::string::npos) { - rootNode->expr.replace(pos, subst.first.size(), - std::string(subst.second)); - // Skip the newly inserted substring, which itself may consider the - // pattern to match. - pos += subst.second.size(); - // Find the next possible match position. - pos = rootNode->expr.find(std::string(subst.first), pos); - } - } + performSubstitutions(rootNode->expr, substitutions); return rootNode; } @@ -170,12 +175,14 @@ const auto &substPred = static_cast(root); allSubstitutions.push_back( {substPred.getPattern(), substPred.getReplacement()}); - } - // If the current predicate is a ConcatPred, record the prefix and suffix. - else if (rootNode->kind == PredCombinerKind::Concat) { + + // If the current predicate is a ConcatPred, record the prefix and suffix. 
+ } else if (rootNode->kind == PredCombinerKind::Concat) { const auto &concatPred = static_cast(root); rootNode->prefix = std::string(concatPred.getPrefix()); + performSubstitutions(rootNode->prefix, substitutions); rootNode->suffix = std::string(concatPred.getSuffix()); + performSubstitutions(rootNode->suffix, substitutions); } // Build child subtrees. diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -103,6 +103,59 @@ // ----- +// CHECK-LABEL: module @variadic_inputs +module @variadic_inputs { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is at_least 2 + + // The first operand has a known index. + // CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT]] : !pdl.value + + // The second operand is a group of unknown size, with a type constraint. + // CHECK-DAG: %[[VAR_INPUTS:.*]] = pdl_interp.get_operands 1 of %[[ROOT]] : !pdl.range + // CHECK-DAG: pdl_interp.is_not_null %[[VAR_INPUTS]] : !pdl.range + + // CHECK-DAG: %[[INPUT_TYPE:.*]] = pdl_interp.get_value_type of %[[VAR_INPUTS]] : !pdl.range + // CHECK-DAG: pdl_interp.check_types %[[INPUT_TYPE]] are [i64] + + // The third operand is at an unknown offset due to operand 2, but is expected + // to be of size 1. 
+ // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operands 2 of %[[ROOT]] : !pdl.value + // CHECK-DAG: pdl_interp.are_equal %[[INPUT]], %[[INPUT2]] : !pdl.value + pdl.pattern : benefit(1) { + %types = pdl.types : [i64] + %inputs = pdl.operands : %types + %input = pdl.operand + %root = pdl.operation(%input, %inputs, %input : !pdl.value, !pdl.range, !pdl.value) + pdl.rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @single_operand_range +module @single_operand_range { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + + // Check that the operand range is treated as all of the operands of the + // operation. + // CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_operands of %[[ROOT]] + // CHECK-DAG: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]] : !pdl.range + // CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPES]] are [i64] + + // The operand count is unknown, so there is no need to check for it. + // CHECK-NOT: pdl_interp.check_operand_count + pdl.pattern : benefit(1) { + %types = pdl.types : [i64] + %operands = pdl.operands : %types + %root = pdl.operation(%operands : !pdl.range) + pdl.rewrite %root with "rewriter" + } +} + +// ----- + // CHECK-LABEL: module @results module @results { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) @@ -127,7 +180,58 @@ // ----- -// CHECK-LABEL: module @results +// CHECK-LABEL: module @variadic_results +module @variadic_results { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT]] is at_least 2 + + // The first result has a known index. + // CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]] + // CHECK-DAG: pdl_interp.is_not_null %[[RESULT]] : !pdl.value + + // The second result is a group of unknown size, with a type constraint. 
+ // CHECK-DAG: %[[VAR_RESULTS:.*]] = pdl_interp.get_results 1 of %[[ROOT]] : !pdl.range + // CHECK-DAG: pdl_interp.is_not_null %[[VAR_RESULTS]] : !pdl.range + + // CHECK-DAG: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[VAR_RESULTS]] : !pdl.range + // CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPE]] are [i64] + + // The third result is at an unknown offset due to result 1, but is expected + // to be of size 1. + // CHECK-DAG: %[[RESULT2:.*]] = pdl_interp.get_results 2 of %[[ROOT]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[RESULT2]] : !pdl.value + pdl.pattern : benefit(1) { + %types = pdl.types : [i64] + %type = pdl.type + %root = pdl.operation -> (%type, %types, %type : !pdl.type, !pdl.range, !pdl.type) + pdl.rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @single_result_range +module @single_result_range { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + + // Check that the result range is treated as all of the results of the + // operation. + // CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]] + // CHECK-DAG: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]] : !pdl.range + // CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPES]] are [i64] + + // The result count is unknown, so there is no need to check for it. 
+ // CHECK-NOT: pdl_interp.check_result_count + pdl.pattern : benefit(1) { + %types = pdl.types : [i64] + %root = pdl.operation -> (%types : !pdl.range) + pdl.rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @results_as_operands module @results_as_operands { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) @@ -160,8 +264,29 @@ // ----- -// CHECK-LABEL: module @switch_result_types -module @switch_result_types { +// CHECK-LABEL: module @single_result_range_as_operands +module @single_result_range_as_operands { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands of %[[ROOT]] : !pdl.range + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[OPERANDS]] : !pdl.range + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] + // CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_results of %[[OP]] : !pdl.range + // CHECK-DAG: pdl_interp.are_equal %[[RESULTS]], %[[OPERANDS]] : !pdl.range + + pdl.pattern : benefit(1) { + %types = pdl.types + %inputOp = pdl.operation -> (%types : !pdl.range) + %results = pdl.results of %inputOp + + %root = pdl.operation(%results : !pdl.range) + pdl.rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @switch_single_result_type +module @switch_single_result_type { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) // CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]] // CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]] @@ -178,6 +303,84 @@ } } +// ----- + +// CHECK-LABEL: module @switch_result_types +module @switch_result_types { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]] + // CHECK: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]] + // CHECK: pdl_interp.switch_types %[[RESULT_TYPES]] to {{\[\[}}i32], [i64, i32]] + pdl.pattern : benefit(1) { + %types = pdl.types : [i32] + %root = pdl.operation -> (%types : 
!pdl.range) + pdl.rewrite %root with "rewriter" + } + pdl.pattern : benefit(1) { + %types = pdl.types : [i64, i32] + %root = pdl.operation -> (%types : !pdl.range) + pdl.rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @switch_operand_count_at_least +module @switch_operand_count_at_least { + // Check that when there are multiple "at_least" checks, the failure branch + // goes to the next one in increasing order. + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: pdl_interp.check_operand_count of %[[ROOT]] is at_least 1 -> ^[[PATTERN_1_NEXT_BLOCK:.*]], + // CHECK: ^bb2: + // CHECK-NEXT: pdl_interp.check_operand_count of %[[ROOT]] is at_least 2 + // CHECK: ^[[PATTERN_1_NEXT_BLOCK]]: + // CHECK-NEXT: {{.*}} -> ^{{.*}}, ^bb2 + pdl.pattern : benefit(1) { + %operand = pdl.operand + %operands = pdl.operands + %root = pdl.operation(%operand, %operands : !pdl.value, !pdl.range) + pdl.rewrite %root with "rewriter" + } + pdl.pattern : benefit(1) { + %operand = pdl.operand + %operand2 = pdl.operand + %operands = pdl.operands + %root = pdl.operation(%operand, %operand2, %operands : !pdl.value, !pdl.value, !pdl.range) + pdl.rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @switch_result_count_at_least +module @switch_result_count_at_least { + // Check that when there are multiple "at_least" checks, the failure branch + // goes to the next one in increasing order. 
+ + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: pdl_interp.check_result_count of %[[ROOT]] is at_least 1 -> ^[[PATTERN_1_NEXT_BLOCK:.*]], + // CHECK: ^[[PATTERN_2_BLOCK:[a-zA-Z_0-9]*]]: + // CHECK: pdl_interp.check_result_count of %[[ROOT]] is at_least 2 + // CHECK: ^[[PATTERN_1_NEXT_BLOCK]]: + // CHECK-NEXT: pdl_interp.get_result + // CHECK-NEXT: pdl_interp.is_not_null {{.*}} -> ^{{.*}}, ^[[PATTERN_2_BLOCK]] + pdl.pattern : benefit(1) { + %type = pdl.type + %types = pdl.types + %root = pdl.operation -> (%type, %types : !pdl.type, !pdl.range) + pdl.rewrite %root with "rewriter" + } + pdl.pattern : benefit(1) { + %type = pdl.type + %type2 = pdl.type + %types = pdl.types + %root = pdl.operation -> (%type, %type2, %types : !pdl.type, !pdl.type, !pdl.range) + pdl.rewrite %root with "rewriter" + } +} + + // ----- // CHECK-LABEL: module @predicate_ordering diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -37,7 +37,7 @@ // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ATTR:.*]]: !pdl.attribute, %[[ROOT:.*]]: !pdl.operation) // CHECK: %[[ATTR1:.*]] = pdl_interp.create_attribute true - // CHECK: pdl_interp.create_operation "foo.op"() {"attr" = %[[ATTR]], "attr1" = %[[ATTR1]]} + // CHECK: pdl_interp.create_operation "foo.op" {"attr" = %[[ATTR]], "attr1" = %[[ATTR1]]} pdl.pattern : benefit(1) { %attr = pdl.attribute %root = pdl.operation "foo.op" {"attr" = %attr} @@ -55,9 +55,9 @@ module @operation_operands { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation) - // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]]) + // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation 
"foo.op"(%[[OPERAND]] : !pdl.value) // CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]] - // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]]) + // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]] : !pdl.value) pdl.pattern : benefit(1) { %operand = pdl.operand %root = pdl.operation "foo.op"(%operand : !pdl.value) @@ -77,9 +77,9 @@ module @operation_operands { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation) - // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]]) + // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]] : !pdl.value) // CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]] - // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]]) + // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]] : !pdl.value) pdl.pattern : benefit(1) { %operand = pdl.operand %root = pdl.operation "foo.op"(%operand : !pdl.value) @@ -95,11 +95,13 @@ // ----- -// CHECK-LABEL: module @operation_result_types -module @operation_result_types { +// CHECK-LABEL: module @operation_infer_types_from_replaceop +module @operation_infer_types_from_replaceop { // CHECK: module @rewriters - // CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type, %[[TYPE1:.*]]: !pdl.type - // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]], %[[TYPE1]] + // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation + // CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]] + // CHECK: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]] + // CHECK: pdl_interp.create_operation "foo.op" -> (%[[RESULT_TYPES]] : !pdl.range) pdl.pattern : benefit(1) { %rootType = pdl.type %rootType1 = pdl.type @@ -114,13 +116,46 @@ // ----- +// CHECK-LABEL: module @operation_infer_types_from_otherop_individual_results +module @operation_infer_types_from_otherop_individual_results { + // CHECK: module @rewriters + 
// CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type, %[[TYPES:.*]]: !pdl.range + // CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range) + pdl.pattern : benefit(1) { + %rootType = pdl.type + %rootTypes = pdl.types + %root = pdl.operation "foo.op" -> (%rootType, %rootTypes : !pdl.type, !pdl.range) + pdl.rewrite %root { + %newOp = pdl.operation "foo.op" -> (%rootType, %rootTypes : !pdl.type, !pdl.range) + } + } +} + +// ----- + +// CHECK-LABEL: module @operation_infer_types_from_otherop_results +module @operation_infer_types_from_otherop_results { + // CHECK: module @rewriters + // CHECK: func @pdl_generated_rewriter(%[[TYPES:.*]]: !pdl.range + // CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPES]] : !pdl.range) + pdl.pattern : benefit(1) { + %rootTypes = pdl.types + %root = pdl.operation "foo.op" -> (%rootTypes : !pdl.range) + pdl.rewrite %root { + %newOp = pdl.operation "foo.op" -> (%rootTypes : !pdl.range) + } + } +} + +// ----- + // CHECK-LABEL: module @replace_with_op module @replace_with_op { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation) // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation - // CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]] - // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]]) + // CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[NEWOP]] + // CHECK: pdl_interp.replace %[[ROOT]] with (%[[RESULTS]] : !pdl.range) pdl.pattern : benefit(1) { %type = pdl.type : i32 %root = pdl.operation "foo.op" -> (%type : !pdl.type) @@ -136,17 +171,19 @@ // CHECK-LABEL: module @replace_with_values module @replace_with_values { // CHECK: module @rewriters - // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation) + // CHECK: func @pdl_generated_rewriter({{.*}}, %[[ROOT:.*]]: !pdl.operation) // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation - // CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]] - 
// CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]]) + // CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]] + // CHECK: %[[RESULTS:.*]] = pdl_interp.get_results 1 of %[[NEWOP]] + // CHECK: pdl_interp.replace %[[ROOT]] with (%[[RESULT]], %[[RESULTS]] : !pdl.value, !pdl.range) pdl.pattern : benefit(1) { - %type = pdl.type : i32 - %root = pdl.operation "foo.op" -> (%type : !pdl.type) + %types = pdl.types + %root = pdl.operation "foo.op" -> (%types : !pdl.range) pdl.rewrite %root { - %newOp = pdl.operation "foo.op" -> (%type : !pdl.type) + %newOp = pdl.operation "foo.op" -> (%types : !pdl.range) %newResult = pdl.result 0 of %newOp - pdl.replace %root with (%newResult : !pdl.value) + %newResults = pdl.results 1 of %newOp + pdl.replace %root with (%newResult, %newResults : !pdl.value, !pdl.range) } } } @@ -175,14 +212,13 @@ // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation) // CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type - // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]] + // CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]] : !pdl.type) pdl.pattern : benefit(1) { %type = pdl.type %root = pdl.operation "foo.op" -> (%type : !pdl.type) pdl.rewrite %root { %newType = pdl.apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type %newOp = pdl.operation "foo.op" -> (%newType : !pdl.type) - pdl.replace %root with %newOp } } } diff --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir --- a/mlir/test/Dialect/PDLInterp/ops.mlir +++ b/mlir/test/Dialect/PDLInterp/ops.mlir @@ -10,16 +10,16 @@ %input: !pdl.value, %type: !pdl.type) { // attributes, operands, and results - %op0 = pdl_interp.create_operation "foo.op"(%input) {"attr" = %attribute} -> %type + %op0 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) {"attr" = %attribute} -> (%type : !pdl.type) // attributes, and results - %op1 = 
pdl_interp.create_operation "foo.op"() {"attr" = %attribute} -> %type + %op1 = pdl_interp.create_operation "foo.op" {"attr" = %attribute} -> (%type : !pdl.type) // attributes - %op2 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute, "attr1" = %attribute} -> () + %op2 = pdl_interp.create_operation "foo.op" {"attr" = %attribute, "attr1" = %attribute} // operands, and results - %op3 = pdl_interp.create_operation "foo.op"(%input) -> %type + %op3 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) -> (%type : !pdl.type) pdl_interp.finalize } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -25,7 +25,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.replaced_by_pattern"() -> () + %op = pdl_interp.create_operation "test.replaced_by_pattern" pdl_interp.erase %root pdl_interp.finalize } @@ -122,7 +122,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -157,7 +157,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -190,7 +190,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -222,7 +222,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -256,7 +256,7 @@ module @rewriters { func @success(%root : 
!pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -288,7 +288,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -325,7 +325,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -375,7 +375,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -425,8 +425,8 @@ ^pat1: %operand0 = pdl_interp.get_operand 0 of %root %operand4 = pdl_interp.get_operand 4 of %root - %defOp0 = pdl_interp.get_defining_op of %operand0 - %defOp4 = pdl_interp.get_defining_op of %operand4 + %defOp0 = pdl_interp.get_defining_op of %operand0 : !pdl.value + %defOp4 = pdl_interp.get_defining_op of %operand4 : !pdl.value pdl_interp.are_equal %defOp0, %defOp4 : !pdl.operation -> ^pat2, ^end ^pat2: @@ -438,7 +438,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -476,8 +476,8 @@ ^pat1: %result0 = pdl_interp.get_result 0 of %root %result4 = pdl_interp.get_result 4 of %root - %result0_type = pdl_interp.get_value_type of %result0 - %result4_type = pdl_interp.get_value_type of %result4 + %result0_type = pdl_interp.get_value_type of %result0 : !pdl.type + %result4_type = pdl_interp.get_value_type of %result4 : !pdl.type pdl_interp.are_equal %result0_type, %result4_type : !pdl.type -> ^pat2, ^end ^pat2: @@ -489,7 
+489,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -513,7 +513,7 @@ // Fully tested within the tests for other operations. //===----------------------------------------------------------------------===// -// pdl_interp::InferredTypeOp +// pdl_interp::InferredTypesOp //===----------------------------------------------------------------------===// // Fully tested within the tests for other operations. @@ -549,7 +549,7 @@ pdl_interp.finalize } func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -582,7 +582,7 @@ module @rewriters { func @success(%root : !pdl.operation) { %operand = pdl_interp.get_operand 0 of %root - pdl_interp.replace %root with (%operand) + pdl_interp.replace %root with (%operand : !pdl.value) pdl_interp.finalize } } @@ -622,7 +622,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -657,7 +657,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -693,7 +693,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } @@ -728,7 +728,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root 
pdl_interp.finalize } @@ -768,7 +768,7 @@ module @rewriters { func @success(%root : !pdl.operation) { - %op = pdl_interp.create_operation "test.success"() -> () + %op = pdl_interp.create_operation "test.success" pdl_interp.erase %root pdl_interp.finalize } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -136,7 +136,7 @@ // DEF: if (!((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>()))) // DEF: if (!(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa())))) // DEF: if (!((tblgen_array_attr.isa<::mlir::ArrayAttr>()))) -// DEF: if (!(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [](::mlir::Attribute attr) { return (some-condition); })))) +// DEF: if (!(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return (some-condition); })))) // DEF: if (!(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>())))) // Test common attribute kind getters' return types