diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -0,0 +1,38 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+class LoadOp:
+  """Specialization for the MemRef load operation."""
+
+  def __init__(self,
+               memref: Union[Operation, OpView, Value],
+               indices: Optional[Union[Operation, OpView,
+                                       Sequence[Value]]] = None,
+               *,
+               loc=None,
+               ip=None):
+    """Creates a memref load operation.
+
+    Args:
+      memref: the buffer to load from (a value of MemRef type).
+      indices: the list of subscripts, may be empty for zero-dimensional
+        buffers.
+      loc: user-visible location of the operation.
+      ip: insertion point.
+    """
+    memref_resolved = _get_op_result_or_value(memref)
+    indices_resolved = [] if indices is None else _get_op_results_or_values(
+        indices)
+    # The loaded value has the memref's element type, not the memref type.
+    return_type = MemRefType(memref_resolved.type).element_type
+    super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -7,8 +7,10 @@
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  return f
 
 # CHECK-LABEL: TEST: testSubViewAccessors
+@run
 def testSubViewAccessors():
   ctx = Context()
   module = Module.parse(r"""
@@ -49,5 +51,23 @@
   print(subview.strides[1])
 
 
-run(testSubViewAccessors)
+# CHECK-LABEL: TEST: testCustomBuilders
+@run
+def testCustomBuilders():
+  with Context() as ctx, Location.unknown(ctx):
+    module = Module.parse(r"""
+      func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
+        return
+      }
+    """)
+    func = module.body.operations[0]
+    func_body = func.regions[0].blocks[0]
+    with InsertionPoint.at_block_terminator(func_body):
+      load = memref.LoadOp(func.arguments[0], func.arguments[1:])
+    # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+    # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+    print(module)
+    # The builder must give the load the memref's element type as result.
+    # CHECK: load result type: f32
+    print("load result type:", load.result.type)
 