diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -17,6 +17,15 @@ extern "C" {
 #endif
 
 
+/// Return true if `name` is the name of a builtin named Linalg op.
+MLIR_CAPI_EXPORTED bool mlirLinalgIsBuiltinNamedOp(MlirDialect linalgDialect,
+                                                   MlirStringRef name);
+
+/// Apply the special region builder for the builtin named Linalg op.
+/// Assert that `op` is a builtin named Linalg op.
+MLIR_CAPI_EXPORTED void
+mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 
 #ifdef __cplusplus
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -37,6 +37,19 @@
   let dependentDialects = [
     "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
   ];
+  let extraClassDeclaration = [{
+    /// Type-erased form of the static `regionBuilder` hook that builtin
+    /// named structured ops register with this dialect.
+    using RegionBuilderFunType = std::function<void(Block &, ValueRange)>;
+    /// Return the region builder registered for the op named `name`, or an
+    /// empty function when no builder is registered under that name.
+    RegionBuilderFunType getRegionBuilder(StringRef name) const {
+      return namedStructuredOpRegionBuilders.lookup(name);
+    }
+  private:
+    /// Op name -> region builder, filled in by `initialize()`.
+    llvm::StringMap<RegionBuilderFunType> namedStructuredOpRegionBuilders;
+  }];
 }
 
 // Whether a type is a RangeType.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -14,6 +14,8 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Types.h"
+#include "llvm/ADT/StringMap.h"
+#include <functional>
 
 #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
 
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -119,6 +119,18 @@
 add_subdirectory(Transforms)
 add_subdirectory(Conversions)
 
+add_mlir_python_extension(MLIRLinalgBindingsPythonExtension _mlirLinalg
+  INSTALL_DIR
+    python
+  SOURCES
+    Linalg.cpp
+)
+target_link_libraries(MLIRLinalgBindingsPythonExtension
+  PRIVATE
+  MLIRCAPILinalg
+)
+add_dependencies(MLIRBindingsPythonExtension MLIRLinalgBindingsPythonExtension)
+
 add_mlir_python_extension(MLIRLinalgPassesBindingsPythonExtension _mlirLinalgPasses
   INSTALL_DIR
     python
diff --git a/mlir/lib/Bindings/Python/Linalg.cpp b/mlir/lib/Bindings/Python/Linalg.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Linalg.cpp
@@ -0,0 +1,35 @@
+//===- Linalg.cpp - Pybind module for Linalg APIs ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Linalg.h"
+#include "IRModule.h"
+#include "mlir-c/IR.h"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+PYBIND11_MODULE(_mlirLinalg, m) {
+  m.doc() = "MLIR Linalg Dialect APIs";
+
+  m.def(
+      "is_builtin",
+      [](PyDialect &dialect, const std::string &opName) {
+        return mlirLinalgIsBuiltinNamedOp(
+            py::cast<PyDialectDescriptor>(dialect.getDescriptor()).get(),
+            mlirStringRefCreate(opName.data(), opName.size()));
+      },
+      py::arg("dialect"), py::arg("opname"),
+      "Return true if `opname` is the name of a builtin named Linalg op.");
+}
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/linalg/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/linalg/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/linalg/__init__.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from _mlirLinalg import *
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -10,5 +10,43 @@
 #include "mlir/CAPI/Registration.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 
+/// Return true if `name` is the name of a builtin named Linalg op.
+bool mlirLinalgIsBuiltinNamedOp(MlirDialect linalgDialect, MlirStringRef name) {
+  return static_cast<mlir::linalg::LinalgDialect *>(unwrap(linalgDialect))
+             ->getRegionBuilder(unwrap(name)) != nullptr;
+}
+
+/// Apply the special region builder for the builtin named Linalg op: create
+/// the entry block of the op's single, empty region and let the registered
+/// builder populate it.
+/// Assert that `op` is a builtin named Linalg op; with assertions disabled,
+/// the call does nothing for any other op.
+void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
+                                        MlirOperation op) {
+  mlir::Operation *operation = unwrap(op);
+  mlir::linalg::LinalgDialect::RegionBuilderFunType fun =
+      static_cast<mlir::linalg::LinalgDialect *>(unwrap(linalgDialect))
+          ->getRegionBuilder(operation->getName().getStringRef());
+  assert(fun && "Expected a builtin named Linalg op.");
+  // Asserts vanish in release builds; never invoke an empty builder.
+  if (!fun)
+    return;
+  assert(operation->getNumRegions() == 1 &&
+         "Expected a Linalg op with exactly one region.");
+  mlir::Region &region = operation->getRegion(0);
+  assert(region.empty() && "Expected a Linalg op with an empty region.");
+
+  // One block argument per operand: the element type of shaped operands,
+  // the operand type itself otherwise.
+  llvm::SmallVector<mlir::Type, 8> argTypes;
+  for (mlir::Type type : operation->getOperandTypes())
+    argTypes.push_back(mlir::getElementTypeOrSelf(type));
+
+  mlir::OpBuilder builder(operation->getContext());
+  mlir::Block *body = builder.createBlock(&region, /*insertPt=*/{}, argTypes);
+  // Captured values are not supported yet: pass an empty range.
+  fun(*body, mlir::ValueRange{});
+}
+
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg,
                                       mlir::linalg::LinalgDialect)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -57,6 +57,39 @@
 // LinalgDialect
 //===----------------------------------------------------------------------===//
 
+/// Trait to check if T provides a `regionBuilder` method.
+template <typename T>
+using has_region_builder = decltype(&T::regionBuilder);
+template <typename T>
+using detect_has_region_builder = llvm::is_detected<has_region_builder, T>;
+
+/// SFINAE helper for op classes without a `regionBuilder` method: no-op.
+template <typename OpType, typename = std::enable_if_t<
+                               !detect_has_region_builder<OpType>::value>>
+void addNamedOpBuilderImpl(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+  // Do nothing.
+}
+
+/// SFINAE helper for op classes with a `regionBuilder` method: register it
+/// under the op name. The extra `typename = void` keeps this template's
+/// signature distinct from the overload above.
+template <typename OpType,
+          typename = std::enable_if_t<detect_has_region_builder<OpType>::value>,
+          typename = void>
+void addNamedOpBuilderImpl(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+  map.try_emplace(OpType::getOperationName(), OpType::regionBuilder);
+}
+
+/// Register the region builder of every op in `OpTypes` that provides one.
+template <typename... OpTypes>
+void addNamedOpBuilders(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+  (void)std::initializer_list<int>{0,
+                                   (addNamedOpBuilderImpl<OpTypes>(map), 0)...};
+}
+
 void mlir::linalg::LinalgDialect::initialize() {
   addTypes<RangeType>();
   addOperations<
@@ -72,6 +105,12 @@
 #include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc"
       >();
 
+  // Fill the Linalg-specific OpName to RegionBuilder map.
+  addNamedOpBuilders<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+      >(namedStructuredOpRegionBuilders);
+
   addInterfaces<LinalgInlinerInterface>();
 }
 
diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -3,6 +3,7 @@
 from mlir.ir import *
 from mlir.dialects import builtin
 from mlir.dialects import linalg
+from mlir.dialects.linalg import linalg as linalg2
 from mlir.dialects import std
 
 
@@ -94,7 +95,7 @@
       def named_form(lhs, rhs):
         init_result = linalg.InitTensorOp([4, 8], f32)
         # CHECK: linalg.matmul
-        # TODO: prperly hook up the region.
+        # TODO: properly hook up the region.
         return linalg.matmul(lhs, rhs, outs=[init_result.result])
 
       @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
@@ -104,4 +105,8 @@
         # CHECK: linalg.generic
         return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)
 
-    print(module)
+    module.operation.print(print_generic_op_form=True)
+
+    # CHECK: is_builtin(linalg.matmul): True
+    print("is_builtin(linalg.matmul):",
+          linalg2.is_builtin(Context.current.dialects.linalg, "linalg.matmul"))