diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -17,6 +17,18 @@ extern "C" { #endif +/// Return true if `name` is the fully-qualified name (including the "linalg." +/// prefix) of a builtin named Linalg op. +MLIR_CAPI_EXPORTED bool mlirLinalgIsBuiltinNamedOp(MlirDialect linalgDialect, + MlirStringRef name); + +/// Apply the special region builder for the builtin named Linalg op. +/// Precondition (only checked by assertions): `op` is a builtin named Linalg +/// op with exactly one region, and that region has no blocks yet. +MLIR_CAPI_EXPORTED void +mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, + MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -37,6 +37,18 @@ let dependentDialects = [ "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect" ]; + let extraClassDeclaration = [{ + /// Signature of the static `regionBuilder` hook of named structured ops. + using RegionBuilderFunType = std::function<void(Block &, ValueRange)>; + /// Return the region builder registered for the fully-qualified op `name`, + /// or an empty function if `name` is not a builtin named Linalg op. + RegionBuilderFunType getRegionBuilder(StringRef name) const { + return namedStructuredOpRegionBuilders.lookup(name); + } + private: + /// Populated by `initialize()` from the ops that define `regionBuilder`. + llvm::StringMap<RegionBuilderFunType> namedStructuredOpRegionBuilders; + }]; } // Whether a type is a RangeType.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -14,6 +14,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Types.h" +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -69,6 +69,7 @@ INSTALL_DIR python SOURCES + LinalgBuiltin.cpp MainModule.cpp IRAffine.cpp IRAttributes.cpp @@ -77,6 +78,8 @@ PybindUtils.cpp Pass.cpp ExecutionEngine.cpp + LINK_LIBS + MLIRCAPILinalg ) add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension) diff --git a/mlir/lib/Bindings/Python/LinalgBuiltin.h b/mlir/lib/Bindings/Python/LinalgBuiltin.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/LinalgBuiltin.h @@ -0,0 +1,22 @@ +//===- LinalgBuiltin.h - Linalg builtin submodule of pybind module --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_LINALGBUILTIN_H +#define MLIR_BINDINGS_PYTHON_LINALGBUILTIN_H + +#include "PybindUtils.h" + +namespace mlir { +namespace python { + +void populateLinalgBuiltinSubmodule(pybind11::module &m); + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_LINALGBUILTIN_H diff --git a/mlir/lib/Bindings/Python/LinalgBuiltin.cpp b/mlir/lib/Bindings/Python/LinalgBuiltin.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/LinalgBuiltin.cpp @@ -0,0 +1,55 @@ +//===- LinalgBuiltin.cpp - Pybind module for Linalg builtin API support ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "mlir-c/Dialect/Linalg.h" +#include "mlir-c/IR.h" + +#include <string> + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +namespace mlir { +namespace python { + +void populateLinalgBuiltinSubmodule(py::module &m) { + m.def( + "is_builtin", + [](PyDialectDescriptor &dialect, const std::string &opName) { + return mlirLinalgIsBuiltinNamedOp( + dialect.get(), mlirStringRefCreate(opName.data(), opName.size())); + }, + py::arg("dialect"), py::arg("opname"), + "Return true if `opname` is the fully-qualified name of a builtin named " + "Linalg op."); + m.def( + "fill_builtin_region", + [](PyDialectDescriptor &dialect, PyOperation &op) { + MlirOperation mlirOp = op.get(); + // The C API only asserts its preconditions; check them here so that + // misuse from Python raises ValueError instead of crashing. + MlirStringRef name = mlirIdentifierStr(mlirOperationGetName(mlirOp)); + if (!mlirLinalgIsBuiltinNamedOp(dialect.get(), name)) + throw py::value_error("Expected a builtin named Linalg op, got '" + + std::string(name.data, name.length) + "'"); + if (mlirOperationGetNumRegions(mlirOp) != 1 || + !mlirBlockIsNull( + mlirRegionGetFirstBlock(mlirOperationGetRegion(mlirOp, 0)))) + throw py::value_error( + "Expected a Linalg op with exactly one empty region"); + mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), mlirOp); + }, + py::arg("dialect"), py::arg("op"), + "Fill the region for `op`, which must be a builtin named Linalg op with " + "a single empty region; raises ValueError otherwise."); +} + +} // namespace python +} // namespace mlir diff --git
a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -13,6 +13,7 @@ #include "ExecutionEngine.h" #include "Globals.h" #include "IRModule.h" +#include "LinalgBuiltin.h" #include "Pass.h" namespace py = pybind11; @@ -225,4 +226,9 @@ auto executionEngineModule = m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); populateExecutionEngineSubmodule(executionEngineModule); + + // Define and populate Linalg submodule. + auto linalgBuiltinModule = + m.def_submodule("linalg_builtin", "MLIR Linalg Builtin Helper Bindings"); + populateLinalgBuiltinSubmodule(linalgBuiltinModule); } diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -10,6 +10,7 @@ import threading from mlir import ir +from mlir import linalg_builtin from .comprehension import * from .config import * from .emitter import * @@ -61,15 +62,13 @@ raise NotImplementedError( f"Emission of composite linalg ops not supported: {op_configs}") - # TODO: this file should probably not be called dsl.py but rather is a client - # of the dsl.py. - from .... import linalg as linalg_ops - emit_generic = (emit_generic or - (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys())) + linalgDialect = ir.Context.current.get_dialect_descriptor("linalg") + fully_qualified_name = 'linalg.'
+ self.op_name + emit_generic = (emit_generic or not linalg_builtin.is_builtin(linalgDialect, fully_qualified_name)) op_config = op_configs[0] if op_config.structured_op: if emit_generic: return emit_generic_structured_op(op_config.structured_op, *args, **kwargs) else: diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -5,6 +5,7 @@ from typing import Dict, Sequence from mlir.ir import * +from mlir import linalg_builtin from mlir.dialects import linalg from mlir.dialects import std @@ -97,11 +98,16 @@ type_mapping, indexing_maps_attr, iterator_types_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) + # If we get here, there must exist a builtin class `op_class_name`. if not op_class_name in linalg.__dict__.keys(): raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") named_op = getattr(linalg, op_class_name)(ins, outs, out_types) + + linalgDialect = Context.current.get_dialect_descriptor("linalg") + linalg_builtin.fill_builtin_region(linalgDialect, named_op.operation) + if len(out_arg_defs) == 1: return named_op.result else: diff --git a/mlir/lib/Bindings/Python/mlir/linalg_builtin.py b/mlir/lib/Bindings/Python/mlir/linalg_builtin.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/linalg_builtin.py @@ -0,0 +1,8 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Simply a wrapper around the extension module of the same name. +from ._cext_loader import _reexport_cext +_reexport_cext("linalg_builtin", __name__) +del _reexport_cext diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -10,5 +10,42 @@ #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, - mlir::linalg::LinalgDialect) +using namespace mlir; +using namespace mlir::linalg; + +/// Return true if `name` is the name of a builtin named Linalg op. +bool mlirLinalgIsBuiltinNamedOp(MlirDialect linalgDialect, MlirStringRef name) { + return static_cast<LinalgDialect *>(unwrap(linalgDialect)) + ->getRegionBuilder(unwrap(name)) != nullptr; +} + +/// Apply the special region builder for the builtin named Linalg op. +/// Assert that `op` is a builtin named Linalg op. +void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, + MlirOperation mlirOp) { + Operation *op = unwrap(mlirOp); + LinalgDialect::RegionBuilderFunType fun = + static_cast<LinalgDialect *>(unwrap(linalgDialect)) + ->getRegionBuilder(op->getName().getStringRef()); + assert(fun && "Expected a builtin named Linalg op."); + assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region"); + assert(op->getRegion(0).getBlocks().empty() && + "Expected Linalg op with 0 blocks"); + // TODO: allow captures.
+ // Block arguments are the element types of the operands; operand types + // that are not shaped are used as-is (see getElementTypeOrSelf). + // TODO: consider LinalgOp::getShapedOperandTypes() instead of all operand + // types. + SmallVector<Type, 8> argTypes; + argTypes.reserve(op->getNumOperands()); + for (Type t : op->getOperandTypes()) + argTypes.push_back(getElementTypeOrSelf(t)); + + // `b` is local, so there is no insertion point to save and restore. + OpBuilder b(op->getContext()); + Region &region = op->getRegion(0); + Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes); + fun(*body, ValueRange{}); +} + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -57,6 +57,38 @@ // LinalgDialect //===----------------------------------------------------------------------===// +/// Trait to check if T provides a `regionBuilder` method. +template +using has_region_builder = decltype(T::regionBuilder); +template +using detect_has_region_builder = llvm::is_detected; + +/// SFINAE helper for single C++ class without a `regionBuilder` method (e.g. +/// an OpInterface). +template ::value>> +void addNamedOpBuilderImpl( + llvm::StringMap &map) { + // Do nothing. +} + +template ::value>, + typename = void> +void addNamedOpBuilderImpl( + llvm::StringMap &map) { + map.insert(std::make_pair( + OpType::getOperationName(), + static_cast(OpType::regionBuilder))); +} + +template +void addNamedOpBuilders( + llvm::StringMap &map) { + (void)std::initializer_list{0, + (addNamedOpBuilderImpl(map), 0)...}; +} + void mlir::linalg::LinalgDialect::initialize() { addTypes(); addOperations< @@ -72,6 +104,12 @@ #include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc" >(); + // Fill the Linalg-specific OpName to RegionBuilder map.
+ addNamedOpBuilders< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >(namedStructuredOpRegionBuilders); + addInterfaces(); } diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py --- a/mlir/test/Bindings/Python/dialects/linalg/ops.py +++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py @@ -94,7 +94,7 @@ def named_form(lhs, rhs): init_result = linalg.InitTensorOp([4, 8], f32) # CHECK: linalg.matmul - # TODO: prperly hook up the region. + # TODO: properly hook up the region. return linalg.matmul(lhs, rhs, outs=[init_result.result]) @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),