diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -409,14 +409,18 @@ let description = [{ `pdl_interp.create_operation` operations create an `Operation` instance with the specified attributes, operands, and result types. See `pdl.operation` - for a more detailed description on the interpretation of the arguments to - this operation. + for a more detailed description of the general interpretation of the arguments + to this operation. Example: ```mlir // Create an instance of a `foo.op` operation. %op = pdl_interp.create_operation "foo.op"(%arg0 : !pdl.value) {"attrA" = %attr0} -> (%type : !pdl.type) + + // Create an instance of a `foo.op` operation that has inferred result types + // (using the InferTypeOpInterface). + %op = pdl_interp.create_operation "foo.op"(%arg0 : !pdl.value) {"attrA" = %attr0} -> <inferred> ``` }]; @@ -424,22 +428,26 @@ Variadic>:$inputOperands, Variadic:$inputAttributes, StrArrayAttr:$inputAttributeNames, - Variadic>:$inputResultTypes); + Variadic>:$inputResultTypes, + UnitAttr:$inferredResultTypes); let results = (outs PDL_Operation:$resultOp); let builders = [ OpBuilder<(ins "StringRef":$name, "ValueRange":$types, - "ValueRange":$operands, "ValueRange":$attributes, - "ArrayAttr":$attributeNames), [{ + "bool":$inferredResultTypes, "ValueRange":$operands, + "ValueRange":$attributes, "ArrayAttr":$attributeNames), [{ build($_builder, $_state, $_builder.getType(), name, - operands, attributes, attributeNames, types); + operands, attributes, attributeNames, types, inferredResultTypes); }]> ]; let assemblyFormat = [{ - $name (`(` $inputOperands^ `:` type($inputOperands) `)`)? + $name (`(` $inputOperands^ `:` type($inputOperands) `)`)? `` custom($inputAttributes, $inputAttributeNames) - (`->` `(` $inputResultTypes^ `:` type($inputResultTypes) `)`)? 
attr-dict + custom($inputResultTypes, type($inputResultTypes), + $inferredResultTypes) + attr-dict }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -961,33 +969,6 @@ ]; } -//===----------------------------------------------------------------------===// -// pdl_interp::InferredTypesOp -//===----------------------------------------------------------------------===// - -def PDLInterp_InferredTypesOp : PDLInterp_Op<"inferred_types"> { - let summary = "Generate a handle to a range of Types that are \"inferred\""; - let description = [{ - `pdl_interp.inferred_types` operations generate handles to ranges of types - that should be inferred. This signals to other operations, such as - `pdl_interp.create_operation`, that these types should be inferred. - - Example: - - ```mlir - %types = pdl_interp.inferred_types - ``` - }]; - let results = (outs PDL_RangeOf:$result); - let assemblyFormat = "attr-dict"; - let builders = [ - OpBuilder<(ins), [{ - build($_builder, $_state, - pdl::RangeType::get($_builder.getType())); - }]> - ]; -} - //===----------------------------------------------------------------------===// // pdl_interp::IsNotNullOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -100,11 +100,12 @@ function_ref mapRewriteValue); /// Generate the values used for resolving the result types of an operation - /// created within a dag rewriter region. + /// created within a dag rewriter region. If the result types of the operation + /// should be inferred, `hasInferredResultTypes` is set to true. 
void generateOperationResultTypeRewriter( - pdl::OperationOp op, SmallVectorImpl &types, - DenseMap &rewriteValues, - function_ref mapRewriteValue); + pdl::OperationOp op, function_ref mapRewriteValue, + SmallVectorImpl &types, DenseMap &rewriteValues, + bool &hasInferredResultTypes); /// A builder to use when generating interpreter operations. OpBuilder builder; @@ -707,15 +708,16 @@ for (Value attr : operationOp.attributes()) attributes.push_back(mapRewriteValue(attr)); + bool hasInferredResultTypes = false; SmallVector types; - generateOperationResultTypeRewriter(operationOp, types, rewriteValues, - mapRewriteValue); + generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types, + rewriteValues, hasInferredResultTypes); // Create the new operation. Location loc = operationOp.getLoc(); Value createdOp = builder.create( - loc, *operationOp.name(), types, operands, attributes, - operationOp.attributeNames()); + loc, *operationOp.name(), types, hasInferredResultTypes, operands, + attributes, operationOp.attributeNames()); rewriteValues[operationOp.op()] = createdOp; // Generate accesses for any results that have their types constrained. @@ -825,9 +827,9 @@ } void PatternLowering::generateOperationResultTypeRewriter( - pdl::OperationOp op, SmallVectorImpl &types, - DenseMap &rewriteValues, - function_ref mapRewriteValue) { + pdl::OperationOp op, function_ref mapRewriteValue, + SmallVectorImpl &types, DenseMap &rewriteValues, + bool &hasInferredResultTypes) { // Look for an operation that was replaced by `op`. The result types will be // inferred from the results that were replaced. Block *rewriterBlock = op->getBlock(); @@ -851,36 +853,54 @@ return; } - // Check if the operation has type inference support. - if (op.hasTypeInference()) { - types.push_back(builder.create(op.getLoc())); - return; - } - - // Otherwise, handle inference for each of the result types individually. + // Try to handle resolution for each of the result types individually. 
This is + // preferred over type inference because it allows existing types to be used + // directly, as opposed to trying to rebuild the type list. OperandRange resultTypeValues = op.types(); - types.reserve(resultTypeValues.size()); - for (const auto &it : llvm::enumerate(resultTypeValues)) { - Value resultType = it.value(); + auto tryResolveResultTypes = [&] { + types.reserve(resultTypeValues.size()); + for (const auto &it : llvm::enumerate(resultTypeValues)) { + Value resultType = it.value(); + + // Check for an already translated value. + if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { + types.push_back(existingRewriteValue); + continue; + } - // Check for an already translated value. - if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { - types.push_back(existingRewriteValue); - continue; - } + // Check for an input from the matcher. + if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { + types.push_back(mapRewriteValue(resultType)); + continue; + } - // Check for an input from the matcher. - if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { - types.push_back(mapRewriteValue(resultType)); - continue; + // Otherwise, we couldn't infer the result types. Bail out here to see if + // we can infer the types for this operation in another way. + types.clear(); + return failure(); } + return success(); + }; + if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes())) + return; - // The verifier asserts that the result types of each pdl.operation can be - // inferred. If we reach here, there is a bug either in the logic above or - // in the verifier for pdl.operation. - op->emitOpError() << "unable to infer result type for operation"; - llvm_unreachable("unable to infer result type for operation"); + // Otherwise, check if the operation has type inference support itself.
+ if (op.hasTypeInference()) { + hasInferredResultTypes = true; + return; } + + // If the types could not be inferred from any context and there weren't any + // explicit result types, assume the user actually meant for the operation to + // have no results. + if (resultTypeValues.empty()) + return; + + // The verifier asserts that the result types of each pdl.operation can be + // inferred. If we reach here, there is a bug either in the logic above or + // in the verifier for pdl.operation. + op->emitOpError() << "unable to infer result type for operation"; + llvm_unreachable("unable to infer result type for operation"); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -47,6 +47,23 @@ // pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// +LogicalResult CreateOperationOp::verify() { + if (!getInferredResultTypes()) + return success(); + if (!getInputResultTypes().empty()) { + return emitOpError("with inferred results cannot also have " + "explicit result types"); + } + OperationName opName(getName(), getContext()); + if (!opName.hasInterface()) { + return emitOpError() + << "has inferred results, but the created operation '" << opName + << "' does not support result type inference (or is not " + "registered)"; + } + return success(); +} + static ParseResult parseCreateOperationOpAttributes( OpAsmParser &p, SmallVectorImpl &attrOperands, @@ -82,6 +99,41 @@ p << '}'; } +static ParseResult parseCreateOperationOpResults( + OpAsmParser &p, + SmallVectorImpl &resultOperands, + SmallVectorImpl &resultTypes, UnitAttr &inferredResultTypes) { + if (failed(p.parseOptionalArrow())) + return success(); + + // Handle the case of inferred results. 
+ if (succeeded(p.parseOptionalLess())) { + if (p.parseKeyword("inferred") || p.parseGreater()) + return failure(); + inferredResultTypes = p.getBuilder().getUnitAttr(); + return success(); + } + + // Otherwise, parse the explicit results. + return failure(p.parseLParen() || p.parseOperandList(resultOperands) || + p.parseColonTypeList(resultTypes) || p.parseRParen()); +} + +static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op, + OperandRange resultOperands, + TypeRange resultTypes, + UnitAttr inferredResultTypes) { + // Handle the case of inferred results. + if (inferredResultTypes) { + p << " -> "; + return; + } + + // Otherwise, handle the explicit results. + if (!resultTypes.empty()) + p << " -> (" << resultOperands << " : " << resultTypes << ")"; +} + //===----------------------------------------------------------------------===// // pdl_interp::ForEachOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -162,6 +162,10 @@ }; } // namespace +/// A marker used to indicate if an operation should infer types. 
+static constexpr ByteCodeField kInferTypesMarker = + std::numeric_limits::max(); + //===----------------------------------------------------------------------===// // ByteCode Generation //===----------------------------------------------------------------------===// @@ -273,7 +277,6 @@ void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); - void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); @@ -723,8 +726,7 @@ LLVM_DEBUG({ // The following list must contain all the operations that do not // produce any bytecode. - if (!isa(op)) + if (!isa(op)) writer.appendInline(op->getLoc()); }); TypeSwitch(op) @@ -742,11 +744,11 @@ pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, pdl_interp::GetResultOp, pdl_interp::GetResultsOp, pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp, - pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, - pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, - pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, - pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, - pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( + pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, + pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, + pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp, + pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, + pdl_interp::SwitchResultCountOp>( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); @@ -847,7 +849,13 @@ writer.append(static_cast(attributes.size())); for (auto it : llvm::zip(op.getInputAttributeNames(), attributes)) 
writer.append(std::get<0>(it), std::get<1>(it)); - writer.appendPDLValueList(op.getInputResultTypes()); + + // Add the result types. If the operation has inferred results, we use a + // marker "size" value. Otherwise, we add the list of explicit result types. + if (op.getInferredResultTypes()) + writer.append(kInferTypesMarker); + else + writer.appendPDLValueList(op.getInputResultTypes()); } void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. @@ -955,12 +963,6 @@ writer.append(OpCode::GetValueType, op.getResult(), op.getValue()); } } - -void Generator::generate(pdl_interp::InferredTypesOp op, - ByteCodeWriter &writer) { - // InferType maps to a null type as a marker for inferring result types. - getMemIndex(op.getResult()) = getMemIndex(Type()); -} void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors()); } @@ -1526,30 +1528,31 @@ state.addAttribute(name, attr); } - for (unsigned i = 0, e = read(); i != e; ++i) { - if (read() == PDLValue::Kind::Type) { - state.types.push_back(read()); - continue; - } - - // If we find a null range, this signals that the types are infered. - if (TypeRange *resultTypes = read()) { - state.types.append(resultTypes->begin(), resultTypes->end()); - continue; - } - - // Handle the case where the operation has inferred types. + // Read in the result types. If the "size" is the sentinel value, this + // indicates that the result types should be inferred. + unsigned numResults = read(); + if (numResults == kInferTypesMarker) { InferTypeOpInterface::Concept *inferInterface = state.name.getRegisteredInfo()->getInterface(); + assert(inferInterface && + "expected operation to provide InferTypeOpInterface"); // TODO: Handle failure. 
- state.types.clear(); if (failed(inferInterface->inferReturnTypes( state.getContext(), state.location, state.operands, state.attributes.getDictionary(state.getContext()), state.regions, state.types))) return; - break; + } else { + // Otherwise, this is a fixed number of results. + for (unsigned i = 0; i != numResults; ++i) { + if (read() == PDLValue::Kind::Type) { + state.types.push_back(read()); + } else { + TypeRange *resultTypes = read(); + state.types.append(resultTypes->begin(), resultTypes->end()); + } + } } Operation *resultOp = rewriter.create(state); diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -127,6 +127,29 @@ // ----- +// CHECK-LABEL: module @operation_infer_types_from_interface +module @operation_infer_types_from_interface { + // Unused operation that ensures the arithmetic dialect is loaded for use in the pattern. 
+ arith.constant true + + // CHECK: module @rewriters + // CHECK: func @pdl_generated_rewriter + // CHECK: %[[CST:.*]] = pdl_interp.create_operation "arith.constant" -> + // CHECK: %[[CST_RES:.*]] = pdl_interp.get_results of %[[CST]] : !pdl.range + // CHECK: %[[CST_TYPE:.*]] = pdl_interp.get_value_type of %[[CST_RES]] : !pdl.range + // CHECK: pdl_interp.create_operation "foo.op" -> (%[[CST_TYPE]] : !pdl.range) + pdl.pattern : benefit(1) { + %root = operation "foo.op" + rewrite %root { + %types = types + %newOp = operation "arith.constant" -> (%types : !pdl.range) + %newOp2 = operation "foo.op" -> (%types : !pdl.range) + } + } +} + +// ----- + // CHECK-LABEL: module @replace_with_op module @replace_with_op { // CHECK: module @rewriters diff --git a/mlir/test/Dialect/PDLInterp/invalid.mlir b/mlir/test/Dialect/PDLInterp/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/PDLInterp/invalid.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateOperationOp +//===----------------------------------------------------------------------===// + +pdl_interp.func @rewriter() { + // expected-error@+1 {{op has inferred results, but the created operation 'foo.op' does not support result type inference}} + %op = pdl_interp.create_operation "foo.op" -> + pdl_interp.finalize +} + +// ----- + +pdl_interp.func @rewriter() { + %type = pdl_interp.create_type i32 + // expected-error@+1 {{op with inferred results cannot also have explicit result types}} + %op = "pdl_interp.create_operation"(%type) { + inferredResultTypes, + inputAttributeNames = [], + name = "foo.op", + operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32> + } : (!pdl.type) -> (!pdl.operation) + pdl_interp.finalize +} + diff --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir --- a/mlir/test/Dialect/PDLInterp/ops.mlir +++ 
b/mlir/test/Dialect/PDLInterp/ops.mlir @@ -6,6 +6,10 @@ // ----- +// Unused operation to force loading the `arithmetic` dialect for the +// test of type inference. +arith.constant true + func.func @operations(%attribute: !pdl.attribute, %input: !pdl.value, %type: !pdl.type) { @@ -21,6 +25,9 @@ // operands, and results %op3 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) -> (%type : !pdl.type) + // inferred results + %op4 = pdl_interp.create_operation "arith.constant" -> + pdl_interp.finalize } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -531,6 +531,41 @@ // pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// +// Unused operation to force loading the `arithmetic` dialect for the +// test of type inference. +arith.constant 10 + +// Test support for inferring the types of an operation. +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation) { + %attr = pdl_interp.create_attribute true + %cst = pdl_interp.create_operation "arith.constant" {"value" = %attr} -> + %cstResults = pdl_interp.get_results of %cst : !pdl.range + %op = pdl_interp.create_operation "test.success"(%cstResults : !pdl.range) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.create_op_infer_results +// CHECK: %[[CST:.*]] = arith.constant true +// CHECK: "test.success"(%[[CST]]) +module @ir attributes { test.create_op_infer_results } { + %results:2 = "test.op"() : () -> (i64, i64) +} + // ----- 
//===----------------------------------------------------------------------===// @@ -1181,12 +1216,6 @@ // Fully tested within the tests for other operations. -//===----------------------------------------------------------------------===// -// pdl_interp::InferredTypesOp -//===----------------------------------------------------------------------===// - -// Fully tested within the tests for other operations. - //===----------------------------------------------------------------------===// // pdl_interp::IsNotNullOp //===----------------------------------------------------------------------===//