diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -109,6 +109,15 @@
   SOURCES dialects/python_test.py
   DIALECT_NAME python_test)
 
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/SCFOps.td
+  SOURCES
+    dialects/scf.py
+    dialects/_scf_ops_ext.py  # Hand-written extensions (e.g. ForOp builder).
+  DIALECT_NAME scf)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/SCFOps.td b/mlir/python/mlir/dialects/SCFOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/SCFOps.td
@@ -0,0 +1,15 @@
+//===-- SCFOps.td - Entry point for SCF dialect bindings ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_SCF_OPS
+#define PYTHON_BINDINGS_SCF_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/SCF/SCFOps.td"
+
+#endif // PYTHON_BINDINGS_SCF_OPS
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -0,0 +1,57 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Any, Sequence
+
+
+class ForOp:
+  """Specialization for the SCF for op class."""
+
+  def __init__(self,
+               lower_bound,
+               upper_bound,
+               step,
+               iter_args: Sequence[Any] = (),  # Immutable: never shared.
+               *,
+               loc=None,
+               ip=None):
+    """Creates an SCF `for` operation and appends its (empty) body block.
+
+    - `lower_bound` is the value to use as lower bound of the loop.
+    - `upper_bound` is the value to use as upper bound of the loop.
+    - `step` is the value to use as loop step.
+    - `iter_args` is a sequence of initial values for loop-carried arguments.
+    """
+    results = [arg.type for arg in iter_args]  # One result per iter arg.
+    super().__init__(
+        self.build_generic(
+            regions=1,
+            results=results,
+            operands=[lower_bound, upper_bound, step] + list(iter_args),
+            loc=loc,
+            ip=ip))
+    self.regions[0].blocks.append(IndexType.get(), *results)  # IV + iter args
+
+  @property
+  def body(self):
+    """Returns the body (block) of the loop."""
+    return self.regions[0].blocks[0]
+
+  @property
+  def induction_variable(self):
+    """Returns the induction variable of the loop."""
+    return self.body.arguments[0]
+
+  @property
+  def inner_iter_args(self):
+    """Returns the loop-carried arguments usable within the loop.
+
+    To obtain the loop-carried operands, use `iter_args`.
+    """
+    return self.body.arguments[1:]
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/scf.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._scf_ops_gen import *
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/scf.py
@@ -0,0 +1,75 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import scf
+from mlir.dialects import builtin
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  return f
+
+
+# CHECK-LABEL: TEST: testSimpleLoop
+@run
+def testSimpleLoop():
+  with Context(), Location.unknown():
+    module = Module.create()
+    index_type = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+      def simple_loop(lb, ub, step):
+        loop = scf.ForOp(lb, ub, step, [lb, lb])
+        with InsertionPoint(loop.body):
+          scf.YieldOp(loop.inner_iter_args)
+        return
+
+    # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+    # CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+    # CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
+    # CHECK: scf.yield %[[I1]], %[[I2]]
+    print(module)
+
+
+# CHECK-LABEL: TEST: testInductionVar
+@run
+def testInductionVar():
+  with Context(), Location.unknown():
+    module = Module.create()
+    index_type = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+      def induction_var(lb, ub, step):
+        loop = scf.ForOp(lb, ub, step, [lb])
+        with InsertionPoint(loop.body):
+          scf.YieldOp([loop.induction_variable])
+        return
+
+    # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+    # CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+    # CHECK: scf.yield %[[IV]]
+    print(module)
+
+
+# CHECK-LABEL: TEST: testNoIterArgs
+@run
+def testNoIterArgs():
+  with Context(), Location.unknown():
+    module = Module.create()
+    index_type = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+      def no_iter_args(lb, ub, step):
+        # Rely on the default (empty) `iter_args`: no loop-carried values.
+        loop = scf.ForOp(lb, ub, step)
+        with InsertionPoint(loop.body):
+          scf.YieldOp([])
+        return
+
+    # CHECK: func @no_iter_args(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+    # CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] {
+    print(module)