diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h --- a/mlir/include/mlir-c/Dialect/PDL.h +++ b/mlir/include/mlir-c/Dialect/PDL.h @@ -49,6 +49,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType); +MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type); + //===---------------------------------------------------------------------===// // TypeType //===---------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -0,0 +1,102 @@ +//===- DialectPDL.cpp - 'pdl' dialect submodule ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/PDL.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::adaptors; + +void populateDialectPDLSubmodule(const pybind11::module &m) { + //===-------------------------------------------------------------------===// + // PDLType + //===-------------------------------------------------------------------===// + + auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType); + + //===-------------------------------------------------------------------===// + // AttributeType + //===-------------------------------------------------------------------===// + + auto attributeType = + mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType); + attributeType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + 
return cls(mlirPDLAttributeTypeGet(ctx)); + }, + "Get an instance of AttributeType in given context.", py::arg("cls"), + py::arg("context") = py::none()); + + //===-------------------------------------------------------------------===// + // OperationType + //===-------------------------------------------------------------------===// + + auto operationType = + mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType); + operationType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirPDLOperationTypeGet(ctx)); + }, + "Get an instance of OperationType in given context.", py::arg("cls"), + py::arg("context") = py::none()); + + //===-------------------------------------------------------------------===// + // RangeType + //===-------------------------------------------------------------------===// + + auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType); + rangeType.def_classmethod( + "get", + [](py::object cls, MlirType elementType) { + return cls(mlirPDLRangeTypeGet(elementType)); + }, + "Gets an instance of RangeType in the same context as the provided " + "element type.", + py::arg("cls"), py::arg("element_type")); + rangeType.def_property_readonly( + "element_type", + [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); }, + "Get the element type."); + + //===-------------------------------------------------------------------===// + // TypeType + //===-------------------------------------------------------------------===// + + auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType); + typeType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirPDLTypeTypeGet(ctx)); + }, + "Get an instance of TypeType in given context.", py::arg("cls"), + py::arg("context") = py::none()); + + //===-------------------------------------------------------------------===// + // ValueType + 
//===-------------------------------------------------------------------===// + + auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType); + valueType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirPDLValueTypeGet(ctx)); + }, + "Get an instance of ValueType in given context.", py::arg("cls"), + py::arg("context") = py::none()); +} + +PYBIND11_MODULE(_mlirDialectsPDL, m) { + m.doc() = "MLIR PDL dialect."; + populateDialectPDLSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp --- a/mlir/lib/CAPI/Dialect/PDL.cpp +++ b/mlir/lib/CAPI/Dialect/PDL.cpp @@ -60,6 +60,10 @@ return wrap(pdl::RangeType::get(unwrap(elementType))); } +MlirType mlirPDLRangeTypeGetElementType(MlirType type) { + return wrap(unwrap(type).cast<pdl::RangeType>().getElementType()); +} + //===---------------------------------------------------------------------===// // TypeType //===---------------------------------------------------------------------===// diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -123,6 +123,15 @@ dialects/quant.py _mlir_libs/_mlir/dialects/quant.pyi) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.pdl + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + SOURCES + dialects/pdl.py + dialects/_pdl_ops_ext.py + _mlir_libs/_mlir/dialects/pdl.pyi) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -243,6 +252,19 @@ MLIRCAPIQuant ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind + MODULE_NAME _mlirDialectsPDL + ADD_TO_PARENT MLIRPythonSources.Dialects.pdl + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectPDL.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIPDL +) + 
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind MODULE_NAME _mlirDialectsSparseTensor ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi @@ -0,0 +1,64 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +from mlir.ir import Type, Context + +__all__ = [ + 'PDLType', + 'AttributeType', + 'OperationType', + 'RangeType', + 'TypeType', + 'ValueType', +] + + +class PDLType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + +class AttributeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> AttributeType: ... + + +class OperationType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> OperationType: ... + + +class RangeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(element_type: Type) -> RangeType: ... + + @property + def element_type(self) -> Type: ... + + +class TypeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> TypeType: ... + + +class ValueType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> ValueType: ... 
diff --git a/mlir/python/mlir/dialects/PDLOps.td b/mlir/python/mlir/dialects/PDLOps.td new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/PDLOps.td @@ -0,0 +1,15 @@ +//===-- PDLOps.td - Entry point for PDLOps bind ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_PDL_OPS +#define PYTHON_BINDINGS_PDL_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/PDL/IR/PDLOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -144,7 +144,8 @@ def get_op_results_or_values( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]] + arg: _Union[_cext.ir.OpView, _cext.ir.Operation, + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]] ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: """Returns the given sequence of values or the results of the given op. @@ -157,4 +158,4 @@ elif isinstance(arg, _cext.ir.Operation): return arg.results else: - return arg + return [get_op_result_or_value(element) for element in arg] diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -0,0 +1,284 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union, Optional, Sequence, List, Mapping +from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values + + +def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr: + """Converts the given value to signless integer attribute of given bit width.""" + if isinstance(value, int): + ty = IntegerType.get_signless(bits) + return IntegerAttr.get(ty, value) + else: + return value + + +def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr: + """Converts the given value to array attribute.""" + if isinstance(attrs, ArrayAttr): + return attrs + else: + return ArrayAttr.get(list(attrs)) + + +def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr: + """Converts the given value to string array attribute.""" + if isinstance(attrs, ArrayAttr): + return attrs + else: + return ArrayAttr.get([StringAttr.get(s) for s in attrs]) + + +def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]: + """Converts the given value to string attribute.""" + if isinstance(name, str): + return StringAttr.get(name) + else: + return name + + +def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr: + """Converts the given value to type attribute.""" + if isinstance(type, Type): + return TypeAttr.get(type) + else: + return type + + +class ApplyNativeConstraintOp: + """Specialization for PDL apply native constraint op class.""" + + def __init__(self, + name: Union[str, StringAttr], + args: Sequence[Union[OpView, Operation, Value]] = [], + params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, + *, + loc=None, + ip=None): + name = _get_str_attr(name) + args = _get_values(args) + params = params if params is None else 
_get_array_attr(params) + super().__init__(name, args, params, loc=loc, ip=ip) + + +class ApplyNativeRewriteOp: + """Specialization for PDL apply native rewrite op class.""" + + def __init__(self, + results: Sequence[Type], + name: Union[str, StringAttr], + args: Sequence[Union[OpView, Operation, Value]] = [], + params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, + *, + loc=None, + ip=None): + name = _get_str_attr(name) + args = _get_values(args) + params = params if params is None else _get_array_attr(params) + super().__init__(results, name, args, params, loc=loc, ip=ip) + + +class AttributeOp: + """Specialization for PDL attribute op class.""" + + def __init__(self, + type: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None): + type = type if type is None else _get_value(type) + result = pdl.AttributeType.get() + super().__init__(result, type, value, loc=loc, ip=ip) + + +class EraseOp: + """Specialization for PDL erase op class.""" + + def __init__(self, + operation: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None): + operation = _get_value(operation) + super().__init__(operation, loc=loc, ip=ip) + + +class OperandOp: + """Specialization for PDL operand op class.""" + + def __init__(self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None): + type = type if type is None else _get_value(type) + result = pdl.ValueType.get() + super().__init__(result, type, loc=loc, ip=ip) + + +class OperandsOp: + """Specialization for PDL operands op class.""" + + def __init__(self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None): + types = types if types is None else _get_value(types) + result = pdl.RangeType.get(pdl.ValueType.get()) + super().__init__(result, types, loc=loc, ip=ip) + + +class OperationOp: + """Specialization for PDL operation op class.""" + + def __init__(self, + name: Optional[Union[str, 
StringAttr]] = None, + args: Sequence[Union[OpView, Operation, Value]] = [], + attributes: Mapping[str, Union[OpView, Operation, Value]] = {}, + types: Sequence[Union[OpView, Operation, Value]] = [], + *, + loc=None, + ip=None): + name = name if name is None else _get_str_attr(name) + args = _get_values(args) + attributeNames = [] + attributeValues = [] + for attrName, attrValue in attributes.items(): + attributeNames.append(StringAttr.get(attrName)) + attributeValues.append(_get_value(attrValue)) + attributeNames = ArrayAttr.get(attributeNames) + types = _get_values(types) + result = pdl.OperationType.get() + super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip) + + +class PatternOp: + """Specialization for PDL pattern op class.""" + + def __init__(self, + benefit: Union[IntegerAttr, int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None): + """Creates a PDL `pattern` operation.""" + name_attr = None if name is None else _get_str_attr(name) + benefit_attr = _get_int_attr(16, benefit) + super().__init__(benefit_attr, name_attr, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + """Return the body (block) of the pattern.""" + return self.regions[0].blocks[0] + + +class ReplaceOp: + """Specialization for PDL replace op class.""" + + def __init__(self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Sequence[Union[OpView, Operation, Value]] = [], + loc=None, + ip=None): + op = _get_value(op) + with_op = with_op if with_op is None else _get_value(with_op) + with_values = _get_values(with_values) + super().__init__(op, with_op, with_values, loc=loc, ip=ip) + + +class ResultOp: + """Specialization for PDL result op class.""" + + def __init__(self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None): + index = _get_int_attr(32, index) + parent = 
_get_value(parent) + result = pdl.ValueType.get() + super().__init__(result, parent, index, loc=loc, ip=ip) + + +class ResultsOp: + """Specialization for PDL results op class.""" + + def __init__(self, + result: Type, + parent: Union[OpView, Operation, Value], + index: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None): + parent = _get_value(parent) + index = index if index is None else _get_int_attr(32, index) + super().__init__(result, parent, index, loc=loc, ip=ip) + + +class RewriteOp: + """Specialization for PDL rewrite op class.""" + + def __init__(self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Sequence[Union[OpView, Operation, Value]] = [], + params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, + *, + loc=None, + ip=None): + root = root if root is None else _get_value(root) + name = name if name is None else _get_str_attr(name) + args = _get_values(args) + params = params if params is None else _get_array_attr(params) + super().__init__(root, name, args, params, loc=loc, ip=ip) + + def add_body(self): + """Add body (block) to the rewrite.""" + self.regions[0].blocks.append() + return self.body + + @property + def body(self): + """Return the body (block) of the rewrite.""" + return self.regions[0].blocks[0] + + +class TypeOp: + """Specialization for PDL type op class.""" + + def __init__(self, + type: Optional[Union[TypeAttr, Type]] = None, + *, + loc=None, + ip=None): + type = type if type is None else _get_type_attr(type) + result = pdl.TypeType.get() + super().__init__(result, type, loc=loc, ip=ip) + + +class TypesOp: + """Specialization for PDL types op class.""" + + def __init__(self, + types: Sequence[Union[TypeAttr, Type]] = [], + *, + loc=None, + ip=None): + types = _get_array_attr([_get_type_attr(ty) for ty in types]) + types = None if not types else types + result = pdl.RangeType.get(pdl.TypeType.get()) + super().__init__(result, types, loc=loc, 
ip=ip) diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/pdl.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._pdl_ops_gen import * +from .._mlir_libs._mlirDialectsPDL import * diff --git a/mlir/test/CAPI/pdl.c b/mlir/test/CAPI/pdl.c --- a/mlir/test/CAPI/pdl.c +++ b/mlir/test/CAPI/pdl.c @@ -146,6 +146,7 @@ MlirType parsedType = mlirTypeParseGet( ctx, mlirStringRefCreateFromCString("!pdl.range<type>")); MlirType constructedType = mlirPDLRangeTypeGet(typeType); + MlirType elementType = mlirPDLRangeTypeGetElementType(constructedType); assert(!mlirTypeIsNull(typeType) && "couldn't get PDLTypeType"); assert(!mlirTypeIsNull(parsedType) && "couldn't parse PDLAttributeType"); @@ -191,11 +192,15 @@ // CHECK: equal: 1 fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType)); + // CHECK: equal: 1 + fprintf(stderr, "equal: %d\n", mlirTypeEqual(typeType, elementType)); // CHECK: !pdl.range<type> mlirTypeDump(parsedType); // CHECK: !pdl.range<type> mlirTypeDump(constructedType); + // CHECK: !pdl.type + mlirTypeDump(elementType); fprintf(stderr, "\n\n"); } diff --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/dialects/pdl_ops.py @@ -0,0 +1,318 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects.pdl import * + + +def constructAndPrintInModule(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + return f + + +# CHECK: module { +# CHECK: pdl.pattern @operations : benefit(1) { +# CHECK: %0 = pdl.attribute +# CHECK: %1 = pdl.type +# CHECK: %2 = pdl.operation {"attr" = %0} -> (%1 : 
!pdl.type) +# CHECK: %3 = pdl.result 0 of %2 +# CHECK: %4 = pdl.operand +# CHECK: %5 = pdl.operation(%3, %4 : !pdl.value, !pdl.value) +# CHECK: pdl.rewrite %5 with "rewriter" +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_operations(): + pattern = PatternOp(1, "operations") + with InsertionPoint(pattern.body): + attr = AttributeOp() + ty = TypeOp() + op0 = OperationOp(attributes={"attr": attr}, types=[ty]) + op0_result = ResultOp(op0, 0) + input = OperandOp() + root = OperationOp(args=[op0_result, input]) + RewriteOp(root, "rewriter") + + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_with_args : benefit(1) { +# CHECK: %0 = pdl.operand +# CHECK: %1 = pdl.operation(%0 : !pdl.value) +# CHECK: pdl.rewrite %1 with "rewriter"(%0 : !pdl.value) +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_with_args(): + pattern = PatternOp(1, "rewrite_with_args") + with InsertionPoint(pattern.body): + input = OperandOp() + root = OperationOp(args=[input]) + RewriteOp(root, "rewriter", args=[input]) + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_with_params : benefit(1) { +# CHECK: %0 = pdl.operation +# CHECK: pdl.rewrite %0 with "rewriter" ["I am param"] +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_with_params(): + pattern = PatternOp(1, "rewrite_with_params") + with InsertionPoint(pattern.body): + op = OperationOp() + RewriteOp(op, "rewriter", params=[StringAttr.get("I am param")]) + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_with_args_and_params : benefit(1) { +# CHECK: %0 = pdl.operand +# CHECK: %1 = pdl.operation(%0 : !pdl.value) +# CHECK: pdl.rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value) +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_with_args_and_params(): + pattern = PatternOp(1, "rewrite_with_args_and_params") + with InsertionPoint(pattern.body): + input = OperandOp() + root = OperationOp(args=[input]) + RewriteOp(root, "rewriter", params=[StringAttr.get("I am 
param")], args=[input]) + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) { +# CHECK: %0 = pdl.operand +# CHECK: %1 = pdl.operand +# CHECK: %2 = pdl.type +# CHECK: %3 = pdl.operation(%0 : !pdl.value) -> (%2 : !pdl.type) +# CHECK: %4 = pdl.result 0 of %3 +# CHECK: %5 = pdl.operation(%4 : !pdl.value) +# CHECK: %6 = pdl.operation(%1 : !pdl.value) -> (%2 : !pdl.type) +# CHECK: %7 = pdl.result 0 of %6 +# CHECK: %8 = pdl.operation(%4, %7 : !pdl.value, !pdl.value) +# CHECK: pdl.rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation) +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_multi_root_optimal(): + pattern = PatternOp(1, "rewrite_multi_root_optimal") + with InsertionPoint(pattern.body): + input1 = OperandOp() + input2 = OperandOp() + ty = TypeOp() + op1 = OperationOp(args=[input1], types=[ty]) + val1 = ResultOp(op1, 0) + root1 = OperationOp(args=[val1]) + op2 = OperationOp(args=[input2], types=[ty]) + val2 = ResultOp(op2, 0) + root2 = OperationOp(args=[val1, val2]) + RewriteOp(name="rewriter", params=[StringAttr.get("I am param")], args=[root1, root2]) + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) { +# CHECK: %0 = pdl.operand +# CHECK: %1 = pdl.operand +# CHECK: %2 = pdl.type +# CHECK: %3 = pdl.operation(%0 : !pdl.value) -> (%2 : !pdl.type) +# CHECK: %4 = pdl.result 0 of %3 +# CHECK: %5 = pdl.operation(%4 : !pdl.value) +# CHECK: %6 = pdl.operation(%1 : !pdl.value) -> (%2 : !pdl.type) +# CHECK: %7 = pdl.result 0 of %6 +# CHECK: %8 = pdl.operation(%4, %7 : !pdl.value, !pdl.value) +# CHECK: pdl.rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation) +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_multi_root_forced(): + pattern = PatternOp(1, "rewrite_multi_root_forced") + with InsertionPoint(pattern.body): + input1 = OperandOp() + input2 = OperandOp() + ty = TypeOp() + op1 = OperationOp(args=[input1], types=[ty]) + val1 = 
ResultOp(op1, 0) + root1 = OperationOp(args=[val1]) + op2 = OperationOp(args=[input2], types=[ty]) + val2 = ResultOp(op2, 0) + root2 = OperationOp(args=[val1, val2]) + RewriteOp(root1, name="rewriter", params=[StringAttr.get("I am param")], args=[root2]) + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_add_body : benefit(1) { +# CHECK: %0 = pdl.type : i32 +# CHECK: %1 = pdl.type +# CHECK: %2 = pdl.operation -> (%0, %1 : !pdl.type, !pdl.type) +# CHECK: pdl.rewrite %2 { +# CHECK: %3 = pdl.type +# CHECK: %4 = pdl.operation "foo.op" -> (%0, %3 : !pdl.type, !pdl.type) +# CHECK: pdl.replace %2 with %4 +# CHECK: } +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_add_body(): + pattern = PatternOp(1, "rewrite_add_body") + with InsertionPoint(pattern.body): + ty1 = TypeOp(IntegerType.get_signless(32)) + ty2 = TypeOp() + root = OperationOp(types=[ty1, ty2]) + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + ty3 = TypeOp() + newOp = OperationOp(name="foo.op", types=[ty1, ty3]) + ReplaceOp(root, with_op=newOp) + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_type : benefit(1) { +# CHECK: %0 = pdl.type : i32 +# CHECK: %1 = pdl.type +# CHECK: %2 = pdl.operation -> (%0, %1 : !pdl.type, !pdl.type) +# CHECK: pdl.rewrite %2 { +# CHECK: %3 = pdl.operation "foo.op" -> (%0, %1 : !pdl.type, !pdl.type) +# CHECK: } +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_type(): + pattern = PatternOp(1, "rewrite_type") + with InsertionPoint(pattern.body): + ty1 = TypeOp(IntegerType.get_signless(32)) + ty2 = TypeOp() + root = OperationOp(types=[ty1, ty2]) + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + newOp = OperationOp(name="foo.op", types=[ty1, ty2]) + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_types : benefit(1) { +# CHECK: %0 = pdl.types +# CHECK: %1 = pdl.operation -> (%0 : !pdl.range<type>) +# CHECK: pdl.rewrite %1 { +# CHECK: %2 = pdl.types : [i32, i64] +# CHECK: %3 = pdl.operation "foo.op" -> 
(%0, %2 : !pdl.range<type>, !pdl.range<type>) +# CHECK: } +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_types(): + pattern = PatternOp(1, "rewrite_types") + with InsertionPoint(pattern.body): + types = TypesOp() + root = OperationOp(types=[types]) + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + otherTypes = TypesOp([IntegerType.get_signless(32), IntegerType.get_signless(64)]) + newOp = OperationOp(name="foo.op", types=[types, otherTypes]) + +# CHECK: module { +# CHECK: pdl.pattern @rewrite_operands : benefit(1) { +# CHECK: %0 = pdl.types +# CHECK: %1 = pdl.operands : %0 +# CHECK: %2 = pdl.operation(%1 : !pdl.range<value>) +# CHECK: pdl.rewrite %2 { +# CHECK: %3 = pdl.operation "foo.op" -> (%0 : !pdl.range<type>) +# CHECK: } +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_rewrite_operands(): + pattern = PatternOp(1, "rewrite_operands") + with InsertionPoint(pattern.body): + types = TypesOp() + operands = OperandsOp(types) + root = OperationOp(args=[operands]) + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + newOp = OperationOp(name="foo.op", types=[types]) + +# CHECK: module { +# CHECK: pdl.pattern @native_rewrite : benefit(1) { +# CHECK: %0 = pdl.operation +# CHECK: pdl.rewrite %0 { +# CHECK: pdl.apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation) +# CHECK: } +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_native_rewrite(): + pattern = PatternOp(1, "native_rewrite") + with InsertionPoint(pattern.body): + root = OperationOp() + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + ApplyNativeRewriteOp([], "NativeRewrite", args=[root]) + +# CHECK: module { +# CHECK: pdl.pattern @attribute_with_value : benefit(1) { +# CHECK: %0 = pdl.operation +# CHECK: pdl.rewrite %0 { +# CHECK: %1 = pdl.attribute "value" +# CHECK: pdl.apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute) +# CHECK: } +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def 
test_attribute_with_value(): + pattern = PatternOp(1, "attribute_with_value") + with InsertionPoint(pattern.body): + root = OperationOp() + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + attr = AttributeOp(value=Attribute.parse('"value"')) + ApplyNativeRewriteOp([], "NativeRewrite", args=[attr]) + +# CHECK: module { +# CHECK: pdl.pattern @erase : benefit(1) { +# CHECK: %0 = pdl.operation +# CHECK: pdl.rewrite %0 { +# CHECK: pdl.erase %0 +# CHECK: } +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_erase(): + pattern = PatternOp(1, "erase") + with InsertionPoint(pattern.body): + root = OperationOp() + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + EraseOp(root) + +# CHECK: module { +# CHECK: pdl.pattern @operation_results : benefit(1) { +# CHECK: %0 = pdl.types +# CHECK: %1 = pdl.operation -> (%0 : !pdl.range<type>) +# CHECK: %2 = pdl.results of %1 +# CHECK: %3 = pdl.operation(%2 : !pdl.range<value>) +# CHECK: pdl.rewrite %3 with "rewriter" +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_operation_results(): + valueRange = RangeType.get(ValueType.get()) + pattern = PatternOp(1, "operation_results") + with InsertionPoint(pattern.body): + types = TypesOp() + inputOp = OperationOp(types=[types]) + results = ResultsOp(valueRange, inputOp) + root = OperationOp(args=[results]) + RewriteOp(root, name="rewriter") + +# CHECK: module { +# CHECK: pdl.pattern : benefit(1) { +# CHECK: %0 = pdl.type +# CHECK: pdl.apply_native_constraint "typeConstraint" [](%0 : !pdl.type) +# CHECK: %1 = pdl.operation -> (%0 : !pdl.type) +# CHECK: pdl.rewrite %1 with "rewrite" +# CHECK: } +# CHECK: } +@constructAndPrintInModule +def test_apply_native_constraint(): + pattern = PatternOp(1) + with InsertionPoint(pattern.body): + resultType = TypeOp() + ApplyNativeConstraintOp("typeConstraint", args=[resultType], params=[]) + root = OperationOp(types=[resultType]) + RewriteOp(root, name="rewrite") diff --git 
a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/dialects/pdl_types.py @@ -0,0 +1,150 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import pdl + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +# CHECK-LABEL: TEST: test_attribute_type +@run +def test_attribute_type(): + with Context(): + parsedType = Type.parse("!pdl.attribute") + constructedType = pdl.AttributeType.get() + + assert pdl.AttributeType.isinstance(parsedType) + assert not pdl.OperationType.isinstance(parsedType) + assert not pdl.RangeType.isinstance(parsedType) + assert not pdl.TypeType.isinstance(parsedType) + assert not pdl.ValueType.isinstance(parsedType) + + assert pdl.AttributeType.isinstance(constructedType) + assert not pdl.OperationType.isinstance(constructedType) + assert not pdl.RangeType.isinstance(constructedType) + assert not pdl.TypeType.isinstance(constructedType) + assert not pdl.ValueType.isinstance(constructedType) + + assert parsedType == constructedType + + # CHECK: !pdl.attribute + print(parsedType) + # CHECK: !pdl.attribute + print(constructedType) + + +# CHECK-LABEL: TEST: test_operation_type +@run +def test_operation_type(): + with Context(): + parsedType = Type.parse("!pdl.operation") + constructedType = pdl.OperationType.get() + + assert not pdl.AttributeType.isinstance(parsedType) + assert pdl.OperationType.isinstance(parsedType) + assert not pdl.RangeType.isinstance(parsedType) + assert not pdl.TypeType.isinstance(parsedType) + assert not pdl.ValueType.isinstance(parsedType) + + assert not pdl.AttributeType.isinstance(constructedType) + assert pdl.OperationType.isinstance(constructedType) + assert not pdl.RangeType.isinstance(constructedType) + assert not pdl.TypeType.isinstance(constructedType) + assert not pdl.ValueType.isinstance(constructedType) + + assert parsedType == constructedType + + # CHECK: !pdl.operation + 
print(parsedType) + # CHECK: !pdl.operation + print(constructedType) + + +# CHECK-LABEL: TEST: test_range_type +@run +def test_range_type(): + with Context(): + typeType = Type.parse("!pdl.type") + parsedType = Type.parse("!pdl.range<type>") + constructedType = pdl.RangeType.get(typeType) + elementType = constructedType.element_type + + assert not pdl.AttributeType.isinstance(parsedType) + assert not pdl.OperationType.isinstance(parsedType) + assert pdl.RangeType.isinstance(parsedType) + assert not pdl.TypeType.isinstance(parsedType) + assert not pdl.ValueType.isinstance(parsedType) + + assert not pdl.AttributeType.isinstance(constructedType) + assert not pdl.OperationType.isinstance(constructedType) + assert pdl.RangeType.isinstance(constructedType) + assert not pdl.TypeType.isinstance(constructedType) + assert not pdl.ValueType.isinstance(constructedType) + + assert parsedType == constructedType + assert elementType == typeType + + # CHECK: !pdl.range<type> + print(parsedType) + # CHECK: !pdl.range<type> + print(constructedType) + # CHECK: !pdl.type + print(elementType) + + +# CHECK-LABEL: TEST: test_type_type +@run +def test_type_type(): + with Context(): + parsedType = Type.parse("!pdl.type") + constructedType = pdl.TypeType.get() + + assert not pdl.AttributeType.isinstance(parsedType) + assert not pdl.OperationType.isinstance(parsedType) + assert not pdl.RangeType.isinstance(parsedType) + assert pdl.TypeType.isinstance(parsedType) + assert not pdl.ValueType.isinstance(parsedType) + + assert not pdl.AttributeType.isinstance(constructedType) + assert not pdl.OperationType.isinstance(constructedType) + assert not pdl.RangeType.isinstance(constructedType) + assert pdl.TypeType.isinstance(constructedType) + assert not pdl.ValueType.isinstance(constructedType) + + assert parsedType == constructedType + + # CHECK: !pdl.type + print(parsedType) + # CHECK: !pdl.type + print(constructedType) + + +# CHECK-LABEL: TEST: test_value_type +@run +def test_value_type(): + with Context(): + 
parsedType = Type.parse("!pdl.value") + constructedType = pdl.ValueType.get() + + assert not pdl.AttributeType.isinstance(parsedType) + assert not pdl.OperationType.isinstance(parsedType) + assert not pdl.RangeType.isinstance(parsedType) + assert not pdl.TypeType.isinstance(parsedType) + assert pdl.ValueType.isinstance(parsedType) + + assert not pdl.AttributeType.isinstance(constructedType) + assert not pdl.OperationType.isinstance(constructedType) + assert not pdl.RangeType.isinstance(constructedType) + assert not pdl.TypeType.isinstance(constructedType) + assert pdl.ValueType.isinstance(constructedType) + + assert parsedType == constructedType + + # CHECK: !pdl.value + print(parsedType) + # CHECK: !pdl.value + print(constructedType)