diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
@@ -2,6 +2,47 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from typing import Optional, Sequence, Union
+from ..ir import *
+from ._ods_common import get_default_loc_context
+
+
+class InitTensorOp:
+
+  def __init__(self,
+               sizes: Union[Sequence[int], Sequence[Value]],
+               element_type: Type,
+               *,
+               loc=None,
+               ip=None):
+    """Constructs an `init_tensor` with either static or dynamic sizes."""
+    context = get_default_loc_context(loc)
+    operands = []
+    attributes = {}
+    # TODO: Refactor the InitTensorOp to take an element type attribute and
+    # then use normal result type inference, unifying the Python and C++ side
+    # with a standard mechanism (versus stashing that in builders).
+    if sizes and isinstance(sizes[0], Value):
+      # Dynamic sizes.
+      operands.extend(sizes)
+      static_size_ints = [-1] * len(sizes)
+      result_type = RankedTensorType.get(static_size_ints, element_type)
+    else:
+      # Static sizes.
+      result_type = RankedTensorType.get(sizes, element_type)
+      static_size_ints = sizes
+
+    index_type = IndexType.get(context)
+    attributes["static_sizes"] = ArrayAttr.get(
+        [IntegerAttr.get(index_type, s) for s in static_size_ints],
+        context=context)
+    op = self.build_generic(results=[result_type],
+                            operands=operands,
+                            attributes=attributes,
+                            loc=loc,
+                            ip=ip)
+    OpView.__init__(self, op)
+
 
 class StructuredOpMixin:
   """All structured ops use the same mixin class."""
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py b/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py
@@ -17,14 +17,15 @@
   """Decorator to extend an OpView class from an extension module.
 
   Extension modules can expose various entry-points:
+    Stand-alone class with the same name as a parent OpView class (i.e.
+    "ReturnOp"). A name-based match is attempted first before falling back
+    to the mechanism below.
+
     def select_opview_mixin(parent_opview_cls):
       If defined, allows an appropriate mixin class to be selected dynamically
       based on the parent OpView class. Should return NotImplemented if a
       decision is not made.
 
-    Stand-alone class with the same name as a parent OpView class (i.e.
-    "ReturnOp").
-
   Args:
     ext_module: A module from which to locate extensions. Can be None if not
      available.
@@ -38,16 +39,18 @@
   if ext_module is None:
     return parent_opview_cls
   mixin_cls = NotImplemented
+  # First try to resolve by name.
   try:
-    select_mixin = getattr(ext_module, "select_opview_mixin")
+    mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
   except AttributeError:
-    # Try to default resolve it.
+    # Fall back to a select_opview_mixin hook.
     try:
-      mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
+      select_mixin = getattr(ext_module, "select_opview_mixin")
     except AttributeError:
       pass
-  else:
-    mixin_cls = select_mixin(parent_opview_cls)
+    else:
+      mixin_cls = select_mixin(parent_opview_cls)
+
   if mixin_cls is NotImplemented or mixin_cls is None:
     return parent_opview_cls
 
diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
--- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -55,15 +55,7 @@
   @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                RankedTensorType.get((16, 8), f32))
   def test_matmul_mono(lhs, rhs):
-    # TODO: Enable outs inference and add sugar for InitTensorOp
-    # construction.
-    init_result = linalg.InitTensorOp(result=RankedTensorType.get((4, 8),
-                                                                  f32),
-                                      static_sizes=ArrayAttr.get([
-                                          IntegerAttr.get(IndexType.get(), 4),
-                                          IntegerAttr.get(IndexType.get(), 8)
-                                      ]),
-                                      sizes=[])
+    init_result = linalg.InitTensorOp([4, 8], f32)
     return matmul_mono(lhs, rhs, outs=[init_result.result])
 
   # CHECK-LABEL: @test_i8i8i32_matmul
diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -9,9 +9,39 @@
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  return f
+
+
+# CHECK-LABEL: TEST: testInitTensor
+@run
+def testInitTensor():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint.at_block_terminator(module.body):
+      # CHECK-LABEL: func @static_sizes
+      # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
+      @builtin.FuncOp.from_py_func()
+      def static_sizes():
+        return linalg.InitTensorOp([3, 4], f32)
+
+      # CHECK-LABEL: func @dynamic_sizes
+      # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+      @builtin.FuncOp.from_py_func(IndexType.get(), IndexType.get())
+      def dynamic_sizes(d0, d1):
+        return linalg.InitTensorOp([d0, d1], f32)
+
+      # CHECK-LABEL: func @zero_d
+      # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
+      @builtin.FuncOp.from_py_func()
+      def zero_d():
+        return linalg.InitTensorOp([], f32)
+
+  print(module)
 
 
 # CHECK-LABEL: TEST: testStructuredOpOnTensors
+@run
 def testStructuredOpOnTensors():
   with Context() as ctx, Location.unknown():
     module = Module.create()
@@ -31,10 +61,8 @@
   print(module)
 
 
-run(testStructuredOpOnTensors)
-
-
# CHECK-LABEL: TEST: testStructuredOpOnBuffers
+@run
 def testStructuredOpOnBuffers():
   with Context() as ctx, Location.unknown():
     module = Module.create()
@@ -52,6 +80,3 @@
 
   # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
   print(module)
-
-
-run(testStructuredOpOnBuffers)
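
For reference, a minimal usage sketch of the new `linalg.InitTensorOp` builder added by this patch, mirroring the `testInitTensor` case in ops.py above. The import paths (`mlir.ir`, `mlir.dialects.builtin`, `mlir.dialects.linalg`) and the function name `make_init` are illustrative assumptions; they presume the Python bindings from this patch are built and importable.

  # Usage sketch (assumed import layout; see the tests above for the canonical form).
  from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
  from mlir.dialects import builtin, linalg

  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint.at_block_terminator(module.body):
      # Static sizes: prints as `linalg.init_tensor [4, 8] : tensor<4x8xf32>`,
      # matching the CHECK patterns in the tests above.
      @builtin.FuncOp.from_py_func()
      def make_init():
        return linalg.InitTensorOp([4, 8], f32)
    print(module)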