diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -436,6 +436,48 @@ let hasRegionVerifier = 1; } +//===----------------------------------------------------------------------===// +// pdl::RangeOp +//===----------------------------------------------------------------------===// + +def PDL_RangeOp : PDL_Op<"range", [NoSideEffect, HasParent<"pdl::RewriteOp">]> { + let summary = "Construct a range of PDL entities"; + let description = [{ + `pdl.range` operations construct a range from a given set of PDL entities, + which all share the same underlying element type. For example, a + `!pdl.range` may be constructed from a list of `!pdl.value` + or `!pdl.range` entities. + + Example: + + ```mlir + // Construct a range of values. + %valueRange = pdl.range %inputValue, %inputRange : !pdl.value, !pdl.range + + // Construct a range of types. + %typeRange = pdl.range %inputType, %inputRange : !pdl.type, !pdl.range + + // Construct an empty range of types. + %emptyTypeRange = pdl.range : !pdl.range + ``` + + TODO: Range construction is currently limited to rewrites, but it could + be extended to constraints under certain circumstances; i.e., if we can + determine how to extract the underlying elements. If we can't, e.g. if + there are multiple sub ranges used for construction, we won't be able + to determine their sizes during constraint time. + }]; + + let arguments = (ins Variadic:$arguments); + let results = (outs PDL_RangeOf>:$result); + let assemblyFormat = [{ + ($arguments^ `:` type($arguments))?
+ custom(ref(type($arguments)), type($result)) + attr-dict + }]; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h --- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h @@ -28,6 +28,11 @@ static bool classof(Type type); }; + +/// If the given type is a range, return its element type, otherwise return +/// the type itself. +Type getRangeElementTypeOrSelf(Type type); + } // namespace pdl } // namespace mlir diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -992,6 +992,43 @@ let assemblyFormat = "$value `:` type($value) attr-dict `->` successors"; } + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateRangeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CreateRangeOp : PDLInterp_Op<"create_range", [NoSideEffect]> { + let summary = "Construct a range of PDL entities"; + let description = [{ + `pdl_interp.create_range` operations construct a range from a given set of PDL + entities, which all share the same underlying element type. For example, a + `!pdl.range` may be constructed from a list of `!pdl.value` + or `!pdl.range` entities. + + Example: + + ```mlir + // Construct a range of values. + %valueRange = pdl_interp.create_range %inputValue, %inputRange : !pdl.value, !pdl.range + + // Construct a range of types. + %typeRange = pdl_interp.create_range %inputType, %inputRange : !pdl.type, !pdl.range + + // Construct an empty range of types.
+ %emptyTypeRange = pdl_interp.create_range : !pdl.range + ``` + }]; + + let arguments = (ins Variadic:$arguments); + let results = (outs PDL_RangeOf>:$result); + let assemblyFormat = [{ + ($arguments^ `:` type($arguments))? + custom(ref(type($arguments)), type($result)) + attr-dict + }]; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // pdl_interp::RecordMatchOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -89,6 +89,9 @@ void generateRewriter(pdl::OperationOp operationOp, DenseMap &rewriteValues, function_ref mapRewriteValue); + void generateRewriter(pdl::RangeOp rangeOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); void generateRewriter(pdl::ReplaceOp replaceOp, DenseMap &rewriteValues, function_ref mapRewriteValue); @@ -666,8 +669,8 @@ for (Operation &rewriteOp : *rewriter.getBody()) { llvm::TypeSwitch(&rewriteOp) .Case([&](auto op) { + pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp, + pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) { this->generateRewriter(op, rewriteValues, mapRewriteValue); }); } @@ -773,6 +776,16 @@ } } +void PatternLowering::generateRewriter( + pdl::RangeOp rangeOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + SmallVector replOperands; + for (Value operand : rangeOp.arguments()) + replOperands.push_back(mapRewriteValue(operand)); + rewriteValues[rangeOp] = builder.create( + rangeOp.getLoc(), rangeOp.getType(), replOperands); +} + void PatternLowering::generateRewriter( pdl::ReplaceOp replaceOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- 
a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -395,6 +395,39 @@ return PDLDialect::getDialectNamespace(); } +//===----------------------------------------------------------------------===// +// pdl::RangeOp +//===----------------------------------------------------------------------===// + +static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, + Type &resultType) { + // If arguments were provided, infer the result type from the argument list. + if (!argumentTypes.empty()) { + resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0])); + return success(); + } + // Otherwise, parse the type as a trailing type. + return p.parseColonType(resultType); +} + +static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes, + Type resultType) { + if (argumentTypes.empty()) + p << ": " << resultType; +} + +LogicalResult RangeOp::verify() { + Type elementType = getType().getElementType(); + for (Type operandType : getOperandTypes()) { + Type operandElementType = getRangeElementTypeOrSelf(operandType); + if (operandElementType != elementType) { + return emitOpError("expected operand to have element type ") + << elementType << ", but got " << operandElementType; + } + } + return success(); +} + //===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp --- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp @@ -59,6 +59,12 @@ return llvm::isa(type.getDialect()); } +Type pdl::getRangeElementTypeOrSelf(Type type) { + if (auto rangeType = type.dyn_cast()) + return rangeType.getElementType(); + return type; +} + //===----------------------------------------------------------------------===// // RangeType
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -237,6 +237,40 @@ return type.isa() ? pdl::RangeType::get(valueTy) : valueTy; } +//===----------------------------------------------------------------------===// +// pdl_interp::CreateRangeOp +//===----------------------------------------------------------------------===// + +static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, + Type &resultType) { + // If arguments were provided, infer the result type from the argument list. + if (!argumentTypes.empty()) { + resultType = + pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0])); + return success(); + } + // Otherwise, parse the type as a trailing type. + return p.parseColonType(resultType); +} + +static void printRangeType(OpAsmPrinter &p, CreateRangeOp op, + TypeRange argumentTypes, Type resultType) { + if (argumentTypes.empty()) + p << ": " << resultType; +} + +LogicalResult CreateRangeOp::verify() { + Type elementType = getType().getElementType(); + for (Type operandType : getOperandTypes()) { + Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType); + if (operandElementType != elementType) { + return emitOpError("expected operand to have element type ") + << elementType << ", but got " << operandElementType; + } + } + return success(); +} + //===----------------------------------------------------------------------===// // pdl_interp::SwitchAttributeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -99,10 +99,14 @@ CheckTypes, /// Continue to the next iteration of a loop.
Continue, + /// Create a type range from a list of constant types. + CreateConstantTypeRange, /// Create an operation. CreateOperation, - /// Create a range of types. - CreateTypes, + /// Create a type range from a list of dynamic types. + CreateTypeRange, + /// Create a value range. + CreateValueRange, /// Erase an operation. EraseOp, /// Extract the op from a range at the specified index. @@ -265,6 +269,7 @@ void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); @@ -742,9 +747,9 @@ pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, - pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, - pdl_interp::CreateTypesOp, pdl_interp::EraseOp, - pdl_interp::ExtractOp, pdl_interp::FinalizeOp, + pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp, + pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, + pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp, pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, @@ -863,12 +868,21 @@ else writer.appendPDLValueList(op.getInputResultTypes()); } +void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) { + // Append the correct opcode for the range type.
+ TypeSwitch(op.getType().getElementType()) + .Case([&](pdl::TypeType) { writer.append(OpCode::CreateTypeRange); }) + .Case([&](pdl::ValueType) { writer.append(OpCode::CreateValueRange); }); + + writer.append(op.getResult(), getRangeStorageIndex(op.getResult())); + writer.appendPDLValueList(op->getOperands()); +} void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.getResult()) = getMemIndex(op.getValue()); } void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { - writer.append(OpCode::CreateTypes, op.getResult(), + writer.append(OpCode::CreateConstantTypeRange, op.getResult(), getRangeStorageIndex(op.getResult()), op.getValue()); } void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { @@ -1103,9 +1117,11 @@ void executeCheckResultCount(); void executeCheckTypes(); void executeContinue(); + void executeCreateConstantTypeRange(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); - void executeCreateTypes(); + template + void executeCreateRange(StringRef type); void executeEraseOp(PatternRewriter &rewriter); template void executeExtract(); @@ -1172,8 +1188,18 @@ } /// Read a list of values from the bytecode buffer. The values may be encoded - /// as either Value or ValueRange elements. - void readValueList(SmallVectorImpl &list) { + /// either as a single element or a range of elements.
+ void readList(SmallVectorImpl &list) { + for (unsigned i = 0, e = read(); i != e; ++i) { + if (read() == PDLValue::Kind::Type) { + list.push_back(read()); + } else { + TypeRange *values = read(); + list.append(values->begin(), values->end()); + } + } + } + void readList(SmallVectorImpl &list) { for (unsigned i = 0, e = read(); i != e; ++i) { if (read() == PDLValue::Kind::Value) { list.push_back(read()); @@ -1292,6 +1318,39 @@ return static_cast(readImpl()); } + /// Assign the given range to the given memory index. This allocates a new + /// range object if necessary. + template > + void assignRangeToMemory(RangeT &&range, unsigned memIndex, + unsigned rangeIndex) { + // Utility functor used to type-erase the assignment. + auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) { + // If the input range is empty, we don't need to allocate anything. + if (llvm::empty(range)) { + rangeMemory[rangeIndex] = {}; + } else { + // Allocate a buffer for this range. + llvm::OwningArrayRef storage(llvm::size(range)); + llvm::copy(range, storage.begin()); + + // Assign this to the range slot and use the range as the value for the + // memory index. + allocatedRangeMemory.emplace_back(std::move(storage)); + rangeMemory[rangeIndex] = allocatedRangeMemory.back(); + } + memory[memIndex] = &rangeMemory[rangeIndex]; + }; + + // Dispatch based on the concrete range type. + if constexpr (std::is_same_v) { + return assignRange(allocatedTypeRangeMemory, typeRangeMemory); + } else if constexpr (std::is_same_v) { + return assignRange(allocatedValueRangeMemory, valueRangeMemory); + } else { + llvm_unreachable("unhandled range type"); + } + } + /// The underlying bytecode buffer.
const ByteCodeField *curCodeIt; @@ -1514,23 +1573,15 @@ popCodeIt(); } -void ByteCodeExecutor::executeCreateTypes() { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); +void ByteCodeExecutor::executeCreateConstantTypeRange() { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); unsigned memIndex = read(); unsigned rangeIndex = read(); ArrayAttr typesAttr = read().cast(); LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); - - // Allocate a buffer for this type range. - llvm::OwningArrayRef storage(typesAttr.size()); - llvm::copy(typesAttr.getAsValueRange(), storage.begin()); - allocatedTypeRangeMemory.emplace_back(std::move(storage)); - - // Assign this to the range slot and use the range as the value for the - // memory index. - typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); - memory[memIndex] = &typeRangeMemory[rangeIndex]; + assignRangeToMemory(typesAttr.getAsValueRange(), memIndex, + rangeIndex); } void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, @@ -1539,7 +1590,7 @@ unsigned memIndex = read(); OperationState state(mainRewriteLoc, read()); - readValueList(state.operands); + readList(state.operands); for (unsigned i = 0, e = read(); i != e; ++i) { StringAttr name = read(); if (Attribute attr = read()) @@ -1587,6 +1638,23 @@ }); } +template +void ByteCodeExecutor::executeCreateRange(StringRef type) { + LLVM_DEBUG(llvm::dbgs() << "Executing Create" << type << "Range:\n"); + unsigned memIndex = read(); + unsigned rangeIndex = read(); + SmallVector values; + readList(values); + + LLVM_DEBUG({ + llvm::dbgs() << " * " << type << "s: "; + llvm::interleaveComma(values, llvm::dbgs()); + llvm::dbgs() << "\n\n"; + }); + + assignRangeToMemory(values, memIndex, rangeIndex); +} + void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); Operation *op = read(); @@ -1949,7 +2017,7 @@ LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
Operation *op = read(); SmallVector args; - readValueList(args); + readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Operation: " << *op << "\n" @@ -2076,11 +2144,17 @@ case Continue: executeContinue(); break; + case CreateConstantTypeRange: + executeCreateConstantTypeRange(); + break; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; - case CreateTypes: - executeCreateTypes(); + case CreateTypeRange: + executeCreateRange("Type"); + break; + case CreateValueRange: + executeCreateRange("Value"); break; case EraseOp: executeEraseOp(rewriter); diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -243,3 +243,20 @@ } // ----- + +// CHECK-LABEL: module @range_op +module @range_op { + // CHECK: module @rewriters + // CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value) + // CHECK: %[[RANGE1:.*]] = pdl_interp.create_range : !pdl.range + // CHECK: %[[RANGE2:.*]] = pdl_interp.create_range %[[OPERAND]], %[[RANGE1]] : !pdl.value, !pdl.range + // CHECK: pdl_interp.finalize + pdl.pattern : benefit(1) { + %operand = pdl.operand + %root = operation "foo.op"(%operand : !pdl.value) + rewrite %root { + %emptyRange = pdl.range : !pdl.range + %range = pdl.range %operand, %emptyRange : !pdl.value, !pdl.range + } + } +} diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -237,6 +237,23 @@ // ----- +//===----------------------------------------------------------------------===// +// pdl::RangeOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + %operand = pdl.operand + %resultType = pdl.type + %root = pdl.operation
"baz.op"(%operand : !pdl.value) -> (%resultType : !pdl.type) + + rewrite %root { + // expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}} + %range = pdl.range %operand, %resultType : !pdl.value, !pdl.type + } +} + +// ----- + //===----------------------------------------------------------------------===// // pdl::ResultsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/PDLInterp/invalid.mlir b/mlir/test/Dialect/PDLInterp/invalid.mlir --- a/mlir/test/Dialect/PDLInterp/invalid.mlir +++ b/mlir/test/Dialect/PDLInterp/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics //===----------------------------------------------------------------------===// -// pdl::CreateOperationOp +// pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// pdl_interp.func @rewriter() { @@ -23,3 +23,15 @@ } : (!pdl.type) -> (!pdl.operation) pdl_interp.finalize } + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateRangeOp +//===----------------------------------------------------------------------===// + +pdl_interp.func @rewriter(%value: !pdl.value, %type: !pdl.type) { + // expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}} + %range = pdl_interp.create_range %value, %type : !pdl.value, !pdl.type + pdl_interp.finalize +} diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -568,6 +568,48 @@ // ----- +//===----------------------------------------------------------------------===// +// pdl_interp::CreateRangeOp +//===----------------------------------------------------------------------===// + +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) {
pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end + + ^pat1: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root: !pdl.operation) { + %rootOperand = pdl_interp.get_operand 0 of %root + %rootOperands = pdl_interp.get_operands of %root : !pdl.range + %operandRange = pdl_interp.create_range %rootOperand, %rootOperands : !pdl.value, !pdl.range + + %operandType = pdl_interp.get_value_type of %rootOperand : !pdl.type + %operandTypes = pdl_interp.get_value_type of %rootOperands : !pdl.range + %typeRange = pdl_interp.create_range %operandType, %operandTypes : !pdl.type, !pdl.range + + %op = pdl_interp.create_operation "test.success"(%operandRange : !pdl.range) -> (%typeRange : !pdl.range) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.create_range_1 +// CHECK: %[[INPUTS:.*]]:2 = "test.input"() +// CHECK: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#0, %[[INPUTS]]#1) : (i32, i32, i32) -> (i32, i32, i32) +module @ir attributes { test.create_range_1 } { + %values:2 = "test.input"() : () -> (i32, i32) + "test.op"(%values#0, %values#1) : (i32, i32) -> () +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::CreateTypeOp //===----------------------------------------------------------------------===//