diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h @@ -0,0 +1,40 @@ +//===- PDLInterp.h - PDL Interpreter dialect --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the interpreter dialect for the PDL pattern descriptor +// language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_ +#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace mlir { +namespace pdl_interp { +//===----------------------------------------------------------------------===// +// PDLInterp Dialect +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/PDLInterp/IR/PDLInterpDialect.h.inc" + +//===----------------------------------------------------------------------===// +// PDLInterp Dialect Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.h.inc" + +} // end namespace pdl_interp +} // end namespace mlir + +#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_ diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -0,0 +1,900 @@ +//===- PDLInterpOps.td 
- Pattern Interpreter Dialect -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the PDL interpreter dialect ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS +#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS + +include "mlir/Dialect/PDL/IR/PDLBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// PDLInterp Dialect +//===----------------------------------------------------------------------===// + +def PDLInterp_Dialect : Dialect { + let summary = "Interpreted pattern execution dialect"; + let description = [{ + The PDL Interpreter dialect provides a lower level abstraction compared to + the PDL dialect, and is targeted towards low level optimization and + interpreter code generation. The dialect operations encapsulate + low-level pattern match and rewrite "primitives", such as navigating the + IR (Operation::getOperand), creating new operations (OpBuilder::create), + etc. Many of the operations within this dialect also fuse branching control + flow with some form of a predicate comparison operation. This type of fusion + reduces the amount of work that an interpreter must do when executing. + }]; + + let name = "pdl_interp"; + let cppNamespace = "mlir::pdl_interp"; +} + +//===----------------------------------------------------------------------===// +// PDLInterp Operations +//===----------------------------------------------------------------------===// + +// Generic interpreter operation. 
+class PDLInterp_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// PDLInterp_PredicateOp + +// Check operations evaluate a predicate on a positional value and then +// conditionally branch on the result. +class PDLInterp_PredicateOp traits = []> : + PDLInterp_Op { + let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); +} + +//===----------------------------------------------------------------------===// +// PDLInterp_SwitchOp + +// Switch operations evaluate a predicate on a positional value and then +// branch to the successor of the matching case value, or to the default. +class PDLInterp_SwitchOp traits = []> : + PDLInterp_Op { + let successors = (successor AnySuccessor:$defaultDest, + VariadicSuccessor:$cases); + + let verifier = [{ + // Verify that the number of case destinations matches the number of case + // values. + size_t numDests = cases().size(); + size_t numValues = caseValues().size(); + if (numDests != numValues) { + return emitOpError("expected number of cases to match the number of case " + "values, got ") + << numDests << " but expected " << numValues; + } + return success(); + }]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::ApplyConstraintOp +//===----------------------------------------------------------------------===// + +def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { + let summary = "Apply a constraint to a set of positional values"; + let description = [{ + `pdl_interp.apply_constraint` operations apply a generic constraint, + that has been registered with the interpreter, with a given set of + positional values. The constraint may have any number of constant + parameters. On success, this operation branches to the true destination, + otherwise the false destination is taken. 
+ + Example: + + ```mlir + // Apply `myConstraint` to the entities defined by `input`, `attr`, and + // `op`. 
+ pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins Variadic:$values, + StrAttr:$name, + ArrayAttr:$params); + let assemblyFormat = [{ + $name $params `(` $values `:` type($values) `)` attr-dict `->` successors + }]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::ApplyRewriteOp +//===----------------------------------------------------------------------===// + +def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> { + let summary = "Invoke and apply an externally registered rewrite method"; + let description = [{ + `pdl_interp.apply_rewrite` operations invoke an external rewriter that has + been registered with the interpreter to perform the rewrite after a + successful match. + + Example: + + ```mlir + pdl_interp.apply_rewrite "rewriter" + ``` + }]; + let arguments = (ins StrAttr:$name); + let assemblyFormat = "$name attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::AreEqualOp +//===----------------------------------------------------------------------===// + +def PDLInterp_AreEqualOp : PDLInterp_PredicateOp<"are_equal", + [SameTypeOperands]> { + let summary = "Check if two positional values are equivalent"; + let description = [{ + `pdl_interp.are_equal` operations compare two positional values for + equivalence. On success, this operation branches to the true destination, + otherwise the false destination is taken. 
+ + Example: + + ```mlir + pdl_interp.are_equal %result1, %result2 : !pdl.value -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins PDL_PositionalValue:$lhs, + PDL_PositionalValue:$rhs); + let assemblyFormat = "operands `:` type($lhs) attr-dict `->` successors"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::BranchOp +//===----------------------------------------------------------------------===// + +def PDLInterp_BranchOp : PDLInterp_Op<"branch", [Terminator]> { + let summary = "General branch operation"; + let description = [{ + `pdl_interp.branch` operations expose general branch functionality to the + interpreter, and are generally used to branch from one pattern match + sequence to another. + + Example: + + ```mlir + pdl_interp.branch ^dest + ``` + }]; + + let successors = (successor AnySuccessor:$dest); + let assemblyFormat = "$dest attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckAttributeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CheckAttributeOp : PDLInterp_PredicateOp<"check_attribute"> { + let summary = "Check the value of an `Attribute`"; + let description = [{ + `pdl_interp.check_attribute` operations compare the value of a given + attribute with a constant value. On success, this operation branches to the + true destination, otherwise the false destination is taken. 
+ + Example: + + ```mlir + pdl_interp.check_attribute %attr is 10 -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins PDL_Attribute:$attribute, AnyAttr:$constantValue); + let assemblyFormat = [{ + $attribute `is` $constantValue attr-dict `->` successors + }]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckOperandCountOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CheckOperandCountOp + : PDLInterp_PredicateOp<"check_operand_count"> { + let summary = "Check the number of operands of an `Operation`"; + let description = [{ + `pdl_interp.check_operand_count` operations compare the number of operands + of a given operation value with a constant. On success, this operation + branches to the true destination, otherwise the false destination is taken. + + Example: + + ```mlir + pdl_interp.check_operand_count of %op is 2 -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, + Confined:$count); + let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckOperationNameOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CheckOperationNameOp + : PDLInterp_PredicateOp<"check_operation_name"> { + let summary = "Check the OperationName of an `Operation`"; + let description = [{ + `pdl_interp.check_operation_name` operations compare the name of a given + operation with a known name. On success, this operation branches to the true + destination, otherwise the false destination is taken. 
+ + Example: + + ```mlir + pdl_interp.check_operation_name of %op is "foo.op" -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, StrAttr:$name); + let assemblyFormat = "`of` $operation `is` $name attr-dict `->` successors"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckResultCountOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CheckResultCountOp : PDLInterp_PredicateOp<"check_result_count"> { + let summary = "Check the number of results of an `Operation`"; + let description = [{ + `pdl_interp.check_result_count` operations compare the number of results + of a given operation value with a constant. On success, this operation + branches to the true destination, otherwise the false destination is taken. + + Example: + + ```mlir + pdl_interp.check_result_count of %op is 0 -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, + Confined:$count); + let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckTypeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CheckTypeOp : PDLInterp_PredicateOp<"check_type"> { + let summary = "Compare a type to a known value"; + let description = [{ + `pdl_interp.check_type` operations compare a type with a statically known + type. On success, this operation branches to the true destination, otherwise + the false destination is taken. 
+ + Example: + + ```mlir + pdl_interp.check_type %type is i32 -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins PDL_Type:$value, TypeAttr:$type); + let assemblyFormat = "$value `is` $type attr-dict `->` successors"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateAttributeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CreateAttributeOp + : PDLInterp_Op<"create_attribute", [NoSideEffect]> { + let summary = "Create an interpreter handle to a constant `Attribute`"; + let description = [{ + `pdl_interp.create_attribute` operations generate a handle within the + interpreter for a specific constant attribute value. + + Example: + + ```mlir + pdl_interp.create_attribute 10 : i64 + ``` + }]; + + let arguments = (ins AnyAttr:$value); + let results = (outs PDL_Attribute:$attribute); + let assemblyFormat = "$value attr-dict"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, " + "Attribute value", [{ + build(builder, state, builder.getType(), value); + }]>]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateNativeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> { + let summary = "Call a native creation method to construct an `Attribute`, " + "`Operation`, `Type`, or `Value`"; + let description = [{ + `pdl_interp.create_native` operations invoke a native C++ function, that has + been registered externally with the consumer of PDL, to create an + `Attribute`, `Operation`, `Type`, or `Value`. The native function must + produce a value of the specified return type, and may accept any number of + positional arguments and constant attribute parameters. 
+ + Example: + + ```mlir + %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute + ``` + }]; + + let arguments = (ins StrAttr:$name, Variadic:$arguments, + ArrayAttr:$constantParams); + let results = (outs PDL_PositionalValue:$result); + let assemblyFormat = [{ + $name $constantParams (`(` $arguments^ `:` type($arguments) `)`)? + `:` type($result) attr-dict + }]; + let verifier = ?; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateOperationOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CreateOperationOp + : PDLInterp_Op<"create_operation", [AttrSizedOperandSegments]> { + let summary = "Create an instance of a specific `Operation`"; + let description = [{ + `pdl_interp.create_operation` operations create an `Operation` instance with + the specified attributes, operands, and result types. + + Example: + + ```mlir + // Create an instance of a `foo.op` operation. 
+ %op = pdl_interp.create_operation "foo.op"(%arg0) {"attrA" = %attr0} -> %type, %type + ``` + }]; + + let arguments = (ins StrAttr:$name, + Variadic:$operands, + Variadic:$attributes, + StrArrayAttr:$attributeNames, + Variadic:$types); + let results = (outs PDL_Operation:$operation); + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, StringRef name, " + "ValueRange types, ValueRange operands, ValueRange attributes, " + "ArrayAttr attributeNames", [{ + build(builder, state, builder.getType(), name, + operands, attributes, attributeNames, types); + }]>]; + let parser = [{ return ::parseCreateOperationOp(parser, result); }]; + let printer = [{ ::print(p, *this); }]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateTypeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CreateTypeOp : PDLInterp_Op<"create_type", [NoSideEffect]> { + let summary = "Create an interpreter handle to a constant `Type`"; + let description = [{ + `pdl_interp.create_type` operations generate a handle within the interpreter + for a specific constant type value. + + Example: + + ```mlir + pdl_interp.create_type i64 + ``` + }]; + + let arguments = (ins TypeAttr:$value); + let results = (outs PDL_Type:$result); + let assemblyFormat = "$value attr-dict"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, TypeAttr type", [{ + build(builder, state, builder.getType(), type); + }]> + ]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::EraseOp +//===----------------------------------------------------------------------===// + +def PDLInterp_EraseOp : PDLInterp_Op<"erase"> { + let summary = "Mark an operation as `erased`"; + let description = [{ + `pdl_interp.erase` operations are used to specify that an operation should be + marked as erased. 
The semantics of this operation correspond with the + `eraseOp` method on a `PatternRewriter`. + + Example: + + ```mlir + pdl_interp.erase %root + ``` + }]; + + let arguments = (ins PDL_Operation:$operation); + let assemblyFormat = "$operation attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::GetAttributeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_GetAttributeOp : PDLInterp_Op<"get_attribute", [NoSideEffect]> { + let summary = "Get a specified attribute value from an `Operation`"; + let description = [{ + `pdl_interp.get_attribute` operations try to get a specific attribute from + an operation. If the operation does not have that attribute, a null value is + returned. + + Example: + + ```mlir + pdl_interp.get_attribute "attr" of %op + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, + StrAttr:$name); + let results = (outs PDL_Attribute:$attribute); + let assemblyFormat = "$name `of` $operation attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::GetAttributeTypeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_GetAttributeTypeOp + : PDLInterp_Op<"get_attribute_type", [NoSideEffect]> { + let summary = "Get the result type of a specified `Attribute`"; + let description = [{ + `pdl_interp.get_attribute_type` operations get the resulting type of a + specific attribute. 
+ + Example: + + ```mlir + pdl_interp.get_attribute_type of %attr + ``` + }]; + + let arguments = (ins PDL_Attribute:$value); + let results = (outs PDL_Type:$result); + let assemblyFormat = "`of` $value attr-dict"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value value", [{ + build(builder, state, builder.getType(), value); + }]> + ]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::GetDefiningOpOp +//===----------------------------------------------------------------------===// + +def PDLInterp_GetDefiningOpOp + : PDLInterp_Op<"get_defining_op", [NoSideEffect]> { + let summary = "Get the defining operation of a `Value`"; + let description = [{ + `pdl_interp.get_defining_op` operations try to get the defining operation + of a specific value. If the value is not an operation result, null is + returned. + + Example: + + ```mlir + pdl_interp.get_defining_op of %value + ``` + }]; + + let arguments = (ins PDL_Value:$value); + let results = (outs PDL_Operation:$operation); + let assemblyFormat = "`of` $value attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::GetOperandOp +//===----------------------------------------------------------------------===// + +def PDLInterp_GetOperandOp : PDLInterp_Op<"get_operand", [NoSideEffect]> { + let summary = "Get a specified operand from an `Operation`"; + let description = [{ + `pdl_interp.get_operand` operations try to get a specific operand from an + operation. If the operation does not have an operand for the given index, a + null value is returned. 
+ + Example: + + ```mlir + pdl_interp.get_operand 1 of %op + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, + Confined:$index); + let results = (outs PDL_Value:$value); + let assemblyFormat = "$index `of` $operation attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::GetResultOp +//===----------------------------------------------------------------------===// + +def PDLInterp_GetResultOp : PDLInterp_Op<"get_result", [NoSideEffect]> { + let summary = "Get a specified result from an `Operation`"; + let description = [{ + `pdl_interp.get_result` operations try to get a specific result from an + operation. If the operation does not have a result for the given index, a + null value is returned. + + Example: + + ```mlir + pdl_interp.get_result 1 of %op + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, + Confined:$index); + let results = (outs PDL_Value:$value); + let assemblyFormat = "$index `of` $operation attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::GetValueTypeOp +//===----------------------------------------------------------------------===// + +// Get the type of a specified positional `Value`. +def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect]> { + let summary = "Get the result type of a specified `Value`"; + let description = [{ + `pdl_interp.get_value_type` operations get the resulting type of a specific + value. 
+ + Example: + + ```mlir + pdl_interp.get_value_type of %value + ``` + }]; + + let arguments = (ins PDL_Value:$value); + let results = (outs PDL_Type:$result); + let assemblyFormat = "`of` $value attr-dict"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value value", [{ + build(builder, state, builder.getType(), value); + }]> + ]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::InferredTypeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_InferredTypeOp : PDLInterp_Op<"inferred_type"> { + let summary = "Generate a handle to a Type that is \"inferred\""; + let description = [{ + `pdl_interp.inferred_type` operations generate a handle to a type that + should be inferred. This signals to other operations, such as + `pdl_interp.create_operation`, that this type should be inferred. + + Example: + + ```mlir + pdl_interp.inferred_type + ``` + }]; + let results = (outs PDL_Type:$type); + let assemblyFormat = "attr-dict"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state", [{ + build(builder, state, builder.getType()); + }]>, + ]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::IsNotNullOp +//===----------------------------------------------------------------------===// + +def PDLInterp_IsNotNullOp : PDLInterp_PredicateOp<"is_not_null"> { + let summary = "Check if a positional value is non-null"; + let description = [{ + `pdl_interp.is_not_null` operations check that a positional value exists. On + success, this operation branches to the true destination. Otherwise, the + false destination is taken. 
+ + Example: + + ```mlir + pdl_interp.is_not_null %value : !pdl.value -> ^matchDest, ^failureDest + ``` + }]; + + let arguments = (ins PDL_PositionalValue:$value); + let assemblyFormat = "$value `:` type($value) attr-dict `->` successors"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::RecordMatchOp +//===----------------------------------------------------------------------===// + +def PDLInterp_RecordMatchOp + : PDLInterp_Op<"record_match", [AttrSizedOperandSegments, Terminator]> { + let summary = "Record the metadata for a successful pattern match"; + let description = [{ + `pdl_interp.record_match` operations record a successful pattern match with + the interpreter and branch to the next part of the matcher. The metadata + recorded by these operations corresponds to a specific `pdl.pattern`, as well + as what values were used during that match that should be propagated to the + rewriter. + + Example: + + ```mlir + pdl_interp.record_match @rewriters::myRewriter(%root : !pdl.operation) : benefit(1), loc([%root, %op1]), root("foo.op") -> ^nextDest + ``` + }]; + + let arguments = (ins Variadic:$inputs, + Variadic:$matchedOps, + SymbolRefAttr:$rewriter, + OptionalAttr:$rootKind, + OptionalAttr:$generatedOps, + OptionalAttr>:$benefit); + let successors = (successor AnySuccessor:$dest); + let assemblyFormat = [{ + $rewriter (`(` $inputs^ `:` type($inputs) `)`)? `:` + `benefit` `(` $benefit `)` `,` + (`generatedOps` `(` $generatedOps^ `)` `,`)? + `loc` `(` `[` $matchedOps `]` `)` + (`,` `root` `(` $rootKind^ `)`)? 
attr-dict `->` $dest + }]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::ReplaceOp +//===----------------------------------------------------------------------===// + +def PDLInterp_ReplaceOp : PDLInterp_Op<"replace"> { + let summary = "Mark an operation as `replaced`"; + let description = [{ + `pdl_interp.replace` operations are used to specify that an operation + should be marked as replaced. The semantics of this operation correspond + with the `replaceOp` method on a `PatternRewriter`. The set of replacement + values must match the number of results specified by the operation. + + Example: + + ```mlir + // Replace root node with 2 values: + pdl_interp.replace %root with (%val0, %val1) + ``` + }]; + let arguments = (ins PDL_Operation:$operation, + Variadic:$replValues); + let assemblyFormat = "$operation `with` `(` $replValues `)` attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::ReturnOp +//===----------------------------------------------------------------------===// + +def PDLInterp_ReturnOp : PDLInterp_Op<"return", [Terminator]> { + let summary = "Terminate a pattern match or rewrite sequence"; + let description = [{ + `pdl_interp.return` is used to denote the termination of a match or rewrite + sequence. + + Example: + + ```mlir + pdl_interp.return + ``` + }]; + let assemblyFormat = "attr-dict"; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchTypeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_SwitchTypeOp : PDLInterp_SwitchOp<"switch_type"> { + let summary = "Switch on a `Type` value"; + let description = [{ + `pdl_interp.switch_type` operations compare a type with a set of statically + known types. 
If the value matches one of the provided case values the + destination for that case value is taken, otherwise the default destination + is taken. + + Example: + + ```mlir + pdl_interp.switch_type %type to [i32, i64](^i32Dest, ^i64Dest) -> ^defaultDest + ``` + }]; + + let arguments = (ins PDL_Type:$value, TypeArrayAttr:$caseValues); + let assemblyFormat = [{ + $value `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest + }]; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value edge, " + "TypeRange types, Block *defaultDest, ArrayRef dests", [{ + build(builder, state, edge, builder.getTypeArrayAttr(types), defaultDest, + dests); + }]>, + ]; + + let extraClassDeclaration = [{ + auto getCaseTypes() { return caseValues().getAsValueRange(); } + }]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchAttributeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_SwitchAttributeOp : PDLInterp_SwitchOp<"switch_attribute"> { + let summary = "Switch on the value of an `Attribute`"; + let description = [{ + `pdl_interp.switch_attribute` operations compare the value of a given + attribute with a set of constant attributes. If the value matches one of the + provided case values the destination for that case value is taken, otherwise + the default destination is taken. 
+ + Example: + + ```mlir + pdl_interp.switch_attribute %attr to [10, true](^tenDest, ^trueDest) -> ^defaultDest + ``` + }]; + let arguments = (ins PDL_Attribute:$attribute, ArrayAttr:$caseValues); + let assemblyFormat = [{ + $attribute `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest + }]; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value attribute, " + "ArrayRef caseValues, " + "Block *defaultDest, ArrayRef dests", [{ + build(builder, state, attribute, builder.getArrayAttr(caseValues), + defaultDest, dests); + }]>]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperandCountOp +//===----------------------------------------------------------------------===// + +def PDLInterp_SwitchOperandCountOp + : PDLInterp_SwitchOp<"switch_operand_count"> { + let summary = "Switch on the operand count of an `Operation`"; + let description = [{ + `pdl_interp.switch_operand_count` operations compare the operand count of a + given operation with a set of potential counts. If the value matches one of + the provided case values the destination for that case value is taken, + otherwise the default destination is taken. 
+ + Example: + + ```mlir + pdl_interp.switch_operand_count of %op to dense<[10, 2]> : vector<2xi32>(^tenDest, ^twoDest) -> ^defaultDest + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, I32ElementsAttr:$caseValues); + let assemblyFormat = [{ + `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest + }]; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, " + "ArrayRef counts, Block *defaultDest, " + "ArrayRef dests", [{ + build(builder, state, operation, builder.getI32VectorAttr(counts), + defaultDest, dests); + }]>]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperationNameOp +//===----------------------------------------------------------------------===// + +def PDLInterp_SwitchOperationNameOp + : PDLInterp_SwitchOp<"switch_operation_name"> { + let summary = "Switch on the OperationName of an `Operation`"; + let description = [{ + `pdl_interp.switch_operation_name` operations compare the name of a given + operation with a set of known names. If the value matches one of the + provided case values the destination for that case value is taken, otherwise + the default destination is taken. 
+ + Example: + + ```mlir + pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"](^fooDest, ^barDest) -> ^defaultDest + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, + StrArrayAttr:$caseValues); + let assemblyFormat = [{ + `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest + }]; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, " + "ArrayRef names, " + "Block *defaultDest, ArrayRef dests", [{ + auto stringNames = llvm::to_vector<8>(llvm::map_range(names, + [](OperationName name) { return name.getStringRef(); })); + build(builder, state, operation, builder.getStrArrayAttr(stringNames), + defaultDest, dests); + }]>, + ]; +} + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchResultCountOp +//===----------------------------------------------------------------------===// + +def PDLInterp_SwitchResultCountOp : PDLInterp_SwitchOp<"switch_result_count"> { + let summary = "Switch on the result count of an `Operation`"; + let description = [{ + `pdl_interp.switch_result_count` operations compare the result count of a + given operation with a set of potential counts. If the value matches one of + the provided case values the destination for that case value is taken, + otherwise the default destination is taken. 
+ + Example: + + ```mlir + pdl_interp.switch_result_count of %op to [0, 2] -> ^0Dest, ^2Dest, ^defaultDest + ``` + }]; + + let arguments = (ins PDL_Operation:$operation, I32ElementsAttr:$caseValues); + let assemblyFormat = [{ + `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest + }]; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, " + "ArrayRef counts, Block *defaultDest, " + "ArrayRef dests", [{ + build(builder, state, operation, builder.getI32VectorAttr(counts), + defaultDest, dests); + }]>]; +} + +#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -217,12 +217,12 @@ public: template - llvm::iterator_range> getAsRange() { + iterator_range> getAsRange() { return llvm::make_range(attr_value_iterator(begin()), attr_value_iterator(end())); } - template - auto getAsRange() { + template + auto getAsValueRange() { return llvm::map_range(getAsRange(), [](AttrTy attr) { return static_cast(attr.getValue()); }); @@ -589,6 +589,9 @@ /// Returns the number of elements held by this attribute. int64_t getNumElements() const; + /// Returns the number of elements held by this attribute. + int64_t size() const { return getNumElements(); } + /// Generates a new ElementsAttr by mapping each int value to a new /// underlying APInt. The new values can represent either an integer or float. /// This ElementsAttr should contain integers. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -139,6 +139,7 @@ ArrayAttr getF32ArrayAttr(ArrayRef values); ArrayAttr getF64ArrayAttr(ArrayRef values); ArrayAttr getStrArrayAttr(ArrayRef values); + ArrayAttr getTypeArrayAttr(TypeRange values); // Affine expressions and affine maps. 
AffineExpr getAffineDimExpr(unsigned position); diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -25,6 +25,7 @@ #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SDBM/SDBMDialect.h" @@ -48,6 +49,7 @@ registerDialect(); registerDialect(); registerDialect(); + registerDialect(); registerDialect(); registerDialect(); registerDialect(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -76,9 +76,7 @@ if (!genericOp) return failure(); - auto mapRange = - genericOp.indexing_maps().getAsRange(); - + auto mapRange = genericOp.indexing_maps().getAsValueRange(); return success( genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && llvm::all_of(mapRange, diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -0,0 +1,122 @@ +//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; +using namespace mlir::pdl_interp; + +//===----------------------------------------------------------------------===// +// PDLInterp Dialect +//===----------------------------------------------------------------------===// + +void PDLInterpDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateOperationOp +//===----------------------------------------------------------------------===// + +static ParseResult parseCreateOperationOp(OpAsmParser &p, + OperationState &state) { + if (p.parseOptionalAttrDict(state.attributes)) + return failure(); + Builder &builder = p.getBuilder(); + + // Parse the operation name. + StringAttr opName; + if (p.parseAttribute(opName, "name", state.attributes)) + return failure(); + + // Parse the operands. + SmallVector operands; + if (p.parseLParen() || p.parseOperandList(operands) || p.parseRParen() || + p.resolveOperands(operands, builder.getType(), + state.operands)) + return failure(); + + // Parse the attributes. 
+ SmallVector attrNames; + if (succeeded(p.parseOptionalLBrace())) { + SmallVector attrOps; + do { + StringAttr nameAttr; + OpAsmParser::OperandType operand; + if (p.parseAttribute(nameAttr) || p.parseEqual() || + p.parseOperand(operand)) + return failure(); + attrNames.push_back(nameAttr); + attrOps.push_back(operand); + } while (succeeded(p.parseOptionalComma())); + + if (p.parseRBrace() || + p.resolveOperands(attrOps, builder.getType(), + state.operands)) + return failure(); + } + state.addAttribute("attributeNames", builder.getArrayAttr(attrNames)); + state.addTypes(builder.getType()); + + // Parse the result types. + SmallVector opResultTypes; + if (p.parseArrow()) + return failure(); + if (succeeded(p.parseOptionalLParen())) { + if (p.parseRParen()) + return failure(); + } else if (p.parseOperandList(opResultTypes) || + p.resolveOperands(opResultTypes, builder.getType(), + state.operands)) { + return failure(); + } + + int32_t operandSegmentSizes[] = {static_cast(operands.size()), + static_cast(attrNames.size()), + static_cast(opResultTypes.size())}; + state.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr(operandSegmentSizes)); + return success(); +} + +static void print(OpAsmPrinter &p, CreateOperationOp op) { + p << "pdl_interp.create_operation "; + p.printOptionalAttrDict(op.getAttrs(), + {"attributeNames", "name", "operand_segment_sizes"}); + p << '"' << op.name() << "\"(" << op.operands() << ')'; + + // Emit the optional attributes. + ArrayAttr attrNames = op.attributeNames(); + if (!attrNames.empty()) { + Operation::operand_range attrArgs = op.attributes(); + p << " {"; + interleaveComma(llvm::seq(0, attrNames.size()), p, + [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); + p << '}'; + } + + // Print the result type constraints of the operation. 
+ auto types = op.types(); + if (types.empty()) + p << " -> ()"; + else + p << " -> " << op.types(); +} + +//===----------------------------------------------------------------------===// +// TableGen Auto-Generated Op and Interface Definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -261,6 +261,12 @@ return getArrayAttr(attrs); } +ArrayAttr Builder::getTypeArrayAttr(TypeRange values) { + auto attrs = llvm::to_vector<8>(llvm::map_range( + values, [this](Type v) -> Attribute { return TypeAttr::get(v); })); + return getArrayAttr(attrs); +} + ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef values) { auto attrs = llvm::to_vector<8>(llvm::map_range( values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); })); diff --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/PDLInterp/ops.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt -split-input-file %s | mlir-opt +// Verify the printed output can be parsed. +// RUN: mlir-opt %s | mlir-opt +// Verify the generic form can be parsed. 
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt + +// ----- + +func @operations(%attribute: !pdl.attribute, + %input: !pdl.value, + %type: !pdl.type) { + // attributes, operands, and results + %op0 = pdl_interp.create_operation "foo.op"(%input) {"attr" = %attribute} -> %type + + // attributes, and results + %op1 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute} -> %type + + // attributes + %op2 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute, "attr1" = %attribute} -> () + + // operands, and results + %op3 = pdl_interp.create_operation "foo.op"(%input) -> %type + + pdl_interp.return +} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -226,7 +226,7 @@ // If there is only one character, this must either be punctuation or a // single character bare identifier. if (value.size() == 1) - return isalpha(front) || StringRef("_:,=<>()[]?").contains(front); + return isalpha(front) || StringRef("_:,=<>()[]{}?").contains(front); // Check the punctuation that are larger than a single character. if (value == "->") @@ -583,6 +583,8 @@ .Case("=", "Equal()") .Case("<", "Less()") .Case(">", "Greater()") + .Case("{", "LBrace()") + .Case("}", "RBrace()") .Case("(", "LParen()") .Case(")", "RParen()") .Case("[", "LSquare()")