# Python bindings for the bufferization dialect.
#
# * CMakeLists.txt: registers a `bufferization` dialect binding through
#   declare_mlir_dialect_python_bindings. It uses TD_FILE
#   dialects/BufferizationOps.td and lists dialects/bufferization.py and
#   dialects/_bufferization_ops_ext.py as SOURCES.
# * bufferization.py (new file): re-exports everything from the generated
#   _bufferization_ops_gen module.
# * _bufferization_ops_ext.py: the alloc_tensor extension's __init__ now takes
#   the full result `tensor_type` plus the `dynamic_sizes` operand values. It
#   forwards them to build_generic with an empty attribute dict. This replaces
#   the old (sizes, element_type) form, which built the RankedTensorType and a
#   "static_sizes" ArrayAttr itself.
#   The get_default_loc_context import also drops its `_get_default_loc_context`
#   alias, so it now matches the unaliased call site in __init__.
# * BUILD.bazel: adds the matching Bazel targets:
#   - BufferizationOpsPyTdFiles (td_library);
#   - BufferizationOpsPyGen (gentbl_filegroup running
#     -gen-python-op-bindings -bind-dialect=bufferization into
#     mlir/dialects/_bufferization_ops_gen.py);
#   - BufferizationOpsPyFiles (filegroup bundling the two .py files and the
#     generated bindings).
#
# NOTE(review): after this change the local `context` in __init__ is never
# read. Its only use was `context=context` in ArrayAttr.get, which this patch
# removes. Consider dropping it.
# NOTE(review): both build rules reference mlir/dialects/BufferizationOps.td,
# which this patch does not add. Confirm the file already exists in the tree.
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -63,6 +63,15 @@
   SOURCES_GLOB dialects/async_dialect/*.py
   DIALECT_NAME async_dialect)
 
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/BufferizationOps.td
+  SOURCES
+    dialects/bufferization.py
+    dialects/_bufferization_ops_ext.py
+  DIALECT_NAME bufferization)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
--- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py
+++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
@@ -5,7 +5,7 @@
 try:
   from typing import Sequence, Union
   from ..ir import *
-  from ._ods_common import get_default_loc_context as _get_default_loc_context
+  from ._ods_common import get_default_loc_context
 
   from typing import Any, List, Union
 except ImportError as e:
@@ -16,36 +16,17 @@
   """Extends the bufferization.alloc_tensor op."""
 
   def __init__(self,
-               sizes: Union[Sequence[int], Sequence[Value]],
-               element_type: Type,
+               tensor_type: Type,
+               dynamic_sizes: Sequence[Value],
                *,
                loc=None,
                ip=None):
-    """Constructs an `alloc_tensor` with either static or dynamic sizes."""
+    """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
     context = get_default_loc_context(loc)
-    operands = []
-    attributes = {}
-    # TODO: Refactor the AllocTensorOp to take an element type attribute and
-    # then use normal result type inference, unifying the Python and C++ side
-    # with a standard mechanism (versus stashing that in builders).
-    if sizes and isinstance(sizes[0], Value):
-      # Dynamic sizes.
-      operands.extend(sizes)
-      static_size_ints = [-1] * len(sizes)
-      result_type = RankedTensorType.get(static_size_ints, element_type)
-    else:
-      # Static sizes.
-      result_type = RankedTensorType.get(sizes, element_type)
-      static_size_ints = sizes
-
-    i64_type = IntegerType.get_signless(64)
-    attributes["static_sizes"] = ArrayAttr.get(
-        [IntegerAttr.get(i64_type, s) for s in static_size_ints],
-        context=context)
     op = self.build_generic(
-        results=[result_type],
-        operands=operands,
-        attributes=attributes,
+        results=[tensor_type],
+        operands=dynamic_sizes,
+        attributes={},
         loc=loc,
         ip=ip)
     OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._bufferization_ops_gen import *
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -252,6 +252,50 @@
     ],
 )
 
+##---------------------------------------------------------------------------##
+# Bufferization dialect.
+##---------------------------------------------------------------------------##
+
+td_library(
+    name = "BufferizationOpsPyTdFiles",
+    srcs = [
+        "//mlir:include/mlir/Bindings/Python/Attributes.td",
+    ],
+    includes = ["../include"],
+    deps = [
+        "//mlir:BufferizableOpInterfaceTdFiles",
+        "//mlir:BufferizationOpsTdFiles",
+        "//mlir:OpBaseTdFiles",
+    ],
+)
+
+gentbl_filegroup(
+    name = "BufferizationOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=bufferization",
+            ],
+            "mlir/dialects/_bufferization_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/BufferizationOps.td",
+    deps = [
+        ":BufferizationOpsPyTdFiles",
+    ],
+)
+
+filegroup(
+    name = "BufferizationOpsPyFiles",
+    srcs = [
+        "mlir/dialects/_bufferization_ops_ext.py",
+        "mlir/dialects/bufferization.py",
+        ":BufferizationOpsPyGen",
+    ],
+)
+
 ##---------------------------------------------------------------------------##
 # ControlFlow dialect.
 ##---------------------------------------------------------------------------##