diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -512,6 +512,13 @@ let arguments = (ins TypeArrayAttr:$value); let results = (outs PDL_RangeOf:$result); let assemblyFormat = "$value attr-dict"; + + let builders = [ + OpBuilder<(ins "ArrayAttr":$type), [{ + build($_builder, $_state, + pdl::RangeType::get($_builder.getType()), type); + }]> + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -236,10 +236,12 @@ return val; // Get the value for the parent position. - Value parentVal = getValueAt(currentBlock, pos->getParent()); + Value parentVal; + if (Position *parent = pos->getParent()) + parentVal = getValueAt(currentBlock, pos->getParent()); // TODO: Use a location from the position. - Location loc = parentVal.getLoc(); + Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc(); builder.setInsertionPointToEnd(currentBlock); Value value; switch (pos->getKind()) { @@ -330,6 +332,22 @@ parentVal, resPos->getResultGroupNumber()); break; } + case Predicates::AttributeLiteralPos: { + auto *attrPos = cast(pos); + value = + builder.create(loc, attrPos->getValue()); + break; + } + case Predicates::TypeLiteralPos: { + auto *typePos = cast(pos); + Attribute rawTypeAttr = typePos->getValue(); + if (TypeAttr typeAttr = rawTypeAttr.dyn_cast()) + value = builder.create(loc, typeAttr); + else + value = builder.create( + loc, rawTypeAttr.cast()); + break; + } default: llvm_unreachable("Generating unknown Position getter"); break; @@ -352,7 +370,7 @@ if (auto *equalToQuestion = dyn_cast(question)) { args = {getValueAt(currentBlock, equalToQuestion->getValue())}; } else if (auto *cstQuestion = dyn_cast(question)) { - for (Position *position : std::get<1>(cstQuestion->getValue())) + for (Position *position : cstQuestion->getArgs()) args.push_back(getValueAt(currentBlock, position)); } @@ -412,10 +430,10 @@ break; } case Predicates::ConstraintQuestion: { - auto value = cast(question)->getValue(); + auto *cstQuestion = cast(question); builder.create( - loc, std::get<0>(value), args, std::get<2>(value).cast(), - success, failure); + loc, cstQuestion->getName(), args, cstQuestion->getParams(), success, + failure); break; } default: diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -50,6 +50,8 @@ ResultPos, ResultGroupPos, TypePos, + AttributeLiteralPos, + TypeLiteralPos, // Questions, ordered by dependency and decreasing priority. IsNotNullQuestion, @@ -173,6 +175,16 @@ StringAttr getName() const { return key.second; } }; +//===----------------------------------------------------------------------===// +// AttributeLiteralPosition + +/// A position describing a literal attribute. +struct AttributeLiteralPosition + : public PredicateBase { + using PredicateBase::PredicateBase; +}; + //===----------------------------------------------------------------------===// // OperandPosition @@ -317,6 +329,17 @@ } }; +//===----------------------------------------------------------------------===// +// TypeLiteralPosition + +/// A position describing a literal type or type range. The value is stored as +/// either a TypeAttr, or an ArrayAttr of TypeAttr. +struct TypeLiteralPosition + : public PredicateBase { + using PredicateBase::PredicateBase; +}; + //===----------------------------------------------------------------------===// // Qualifiers //===----------------------------------------------------------------------===// @@ -404,6 +427,17 @@ Predicates::ConstraintQuestion> { using Base::Base; + /// Return the name of the constraint. + StringRef getName() const { return std::get<0>(key); } + + /// Return the arguments of the constraint. + ArrayRef getArgs() const { return std::get<1>(key); } + + /// Return the constant parameters of the constraint. + ArrayAttr getParams() const { + return std::get<2>(key).dyn_cast_or_null(); + } + /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { @@ -461,12 +495,14 @@ PredicateUniquer() { // Register the types of Positions with the uniquer. registerParametricStorageType(); + registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); + registerParametricStorageType(); // Register the types of Questions with the uniquer. registerParametricStorageType(); @@ -527,6 +563,11 @@ return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); } + /// Returns an attribute position for the given attribute. + Position *getAttributeLiteral(Attribute attr) { + return AttributeLiteralPosition::get(uniquer, attr); + } + /// Returns an operand position for an operand of the given operation. Position *getOperand(OperationPosition *p, unsigned operand) { return OperandPosition::get(uniquer, p, operand); @@ -558,6 +599,12 @@ /// Returns a type position for the given entity. Position *getType(Position *p) { return TypePosition::get(uniquer, p); } + /// Returns a type position for the given type value. The value is stored + /// as either a TypeAttr, or an ArrayAttr of TypeAttr. + Position *getTypeLiteral(Attribute attr) { + return TypeLiteralPosition::get(uniquer, attr); + } + //===--------------------------------------------------------------------===// // Qualifiers //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp @@ -21,7 +21,7 @@ unsigned Position::getOperationDepth() const { if (const auto *operationPos = dyn_cast(this)) return operationPos->getDepth(); - return parent->getOperationDepth(); + return parent ? parent->getOperationDepth() : 0; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -243,8 +243,18 @@ .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); } -/// Collect all of the predicates related to constraints within the given -/// pattern operation. +static void getAttributePredicates(pdl::AttributeOp op, + std::vector &predList, + PredicateBuilder &builder, + DenseMap &inputs) { + Position *&attrPos = inputs[op]; + if (attrPos) + return; + Attribute value = op.valueAttr(); + assert(value && "expected non-tree `pdl.attribute` to contain a value"); + attrPos = builder.getAttributeLiteral(value); +} + static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector &predList, PredicateBuilder &builder, @@ -296,6 +306,19 @@ predList.emplace_back(resultPos, builder.getIsNotNull()); } +static void getTypePredicates(Value typeValue, + function_ref typeAttrFn, + PredicateBuilder &builder, + DenseMap &inputs) { + Position *&typePos = inputs[typeValue]; + if (typePos) + return; + Attribute typeAttr = typeAttrFn(); + assert(typeAttr && + "expected non-tree `pdl.type`/`pdl.types` to contain a value"); + typePos = builder.getTypeLiteral(typeAttr); +} + /// Collect all of the predicates that cannot be determined via walking the /// tree. static void getNonTreePredicates(pdl::PatternOp pattern, @@ -304,11 +327,22 @@ DenseMap &inputs) { for (Operation &op : pattern.body().getOps()) { TypeSwitch(&op) + .Case([&](pdl::AttributeOp attrOp) { + getAttributePredicates(attrOp, predList, builder, inputs); + }) .Case([&](auto constraintOp) { getConstraintPredicates(constraintOp, predList, builder, inputs); }) .Case([&](auto resultOp) { getResultPredicates(resultOp, predList, builder, inputs); + }) + .Case([&](pdl::TypeOp typeOp) { + getTypePredicates( + typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs); + }) + .Case([&](pdl::TypesOp typeOp) { + getTypePredicates( + typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs); }); } } diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -114,12 +114,15 @@ Value attrType = op.type(); Optional attrValue = op.value(); - if (!attrValue && isa(op->getParentOp())) - return op.emitOpError("expected constant value when specified within a " - "`pdl.rewrite`"); - if (attrValue && attrType) + if (!attrValue) { + if (isa(op->getParentOp())) + return op.emitOpError("expected constant value when specified within a " + "`pdl.rewrite`"); + return verifyHasBindingUse(op); + } + if (attrType) return op.emitOpError("expected only one of [`type`, `value`] to be set"); - return verifyHasBindingUse(op); + return success(); } //===----------------------------------------------------------------------===// @@ -431,13 +434,21 @@ // pdl::TypeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); } +static LogicalResult verify(TypeOp op) { + if (!op.typeAttr()) + return verifyHasBindingUse(op); + return success(); +} //===----------------------------------------------------------------------===// // pdl::TypesOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); } +static LogicalResult verify(TypesOp op) { + if (!op.typesAttr()) + return verifyHasBindingUse(op); + return success(); +} //===----------------------------------------------------------------------===// // TableGen'd op method definitions diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -573,3 +573,42 @@ pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation) } } + +// ----- + +// CHECK-LABEL: module @attribute_literal +module @attribute_literal { + // CHECK: func @matcher(%{{.*}}: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.create_attribute 10 : i64 + // CHECK: pdl_interp.apply_constraint "constraint"(%[[ATTR]] : !pdl.attribute) + + // Check the correct lowering of an attribute that hasn't been bound. + pdl.pattern : benefit(1) { + %attr = pdl.attribute 10 + pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute) + + %root = pdl.operation + pdl.rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @type_literal +module @type_literal { + // CHECK: func @matcher(%{{.*}}: !pdl.operation) + // CHECK: %[[TYPE:.*]] = pdl_interp.create_type i32 + // CHECK: %[[TYPES:.*]] = pdl_interp.create_types [i32, i64] + // CHECK: pdl_interp.apply_constraint "constraint"(%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range) + + // Check the correct lowering of a type that hasn't been bound. + pdl.pattern : benefit(1) { + %type = pdl.type : i32 + %types = pdl.types : [i32, i64] + pdl.apply_native_constraint "constraint"(%type, %types: !pdl.type, !pdl.range) + + %root = pdl.operation + pdl.rewrite %root with "rewriter" + } +} +