diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -17,6 +17,11 @@ extern "C" { #endif +/// Apply the special region builder for the builtin named Linalg op. +/// Assert that `op` is a builtin named Linalg op. +MLIR_CAPI_EXPORTED void +mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -37,6 +37,14 @@ let dependentDialects = [ "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect" ]; + let extraClassDeclaration = [{ + using RegionBuilderFunType = llvm::function_ref; + RegionBuilderFunType getRegionBuilder(StringRef name) { + return namedStructuredOpRegionBuilders.lookup(name); + } + private: + llvm::StringMap namedStructuredOpRegionBuilders; + }]; } // Whether a type is a RangeType. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -14,6 +14,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Types.h" +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -69,6 +69,7 @@ INSTALL_DIR python SOURCES + DialectLinalg.cpp MainModule.cpp IRAffine.cpp IRAttributes.cpp diff --git a/mlir/lib/Bindings/Python/DialectLinalg.h b/mlir/lib/Bindings/Python/DialectLinalg.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectLinalg.h @@ -0,0 +1,22 @@ +//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H +#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H + +#include "PybindUtils.h" + +namespace mlir { +namespace python { + +void populateDialectLinalgSubmodule(pybind11::module &m); + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -0,0 +1,34 @@ +//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "mlir-c/Dialect/Linalg.h" +#include "mlir-c/IR.h" + +#include + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +namespace mlir { +namespace python { + +void populateDialectLinalgSubmodule(py::module &m) { + m.def( + "fill_builtin_region", + [](PyDialectDescriptor &dialect, PyOperation &op) { + return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get()); + }, + py::arg("dialect"), py::arg("op"), + "Fill the region for `op`, which is assumed to be a builtin named Linalg " + "op."); +} + +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -10,6 +10,7 @@ #include "PybindUtils.h" +#include "DialectLinalg.h" #include "ExecutionEngine.h" #include "Globals.h" #include "IRModule.h" @@ -225,4 +226,9 @@ auto executionEngineModule = m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); populateExecutionEngineSubmodule(executionEngineModule); + + // Define and populate Linalg submodule. + auto dialectsModule = m.def_submodule("dialects"); + auto linalgModule = dialectsModule.def_submodule("linalg"); + populateDialectLinalgSubmodule(linalgModule); } diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -61,11 +61,10 @@ raise NotImplementedError( f"Emission of composite linalg ops not supported: {op_configs}") - # TODO: this file should probably not be called dsl.py but rather is a client - # of the dsl.py. - from .... import linalg as linalg_ops - emit_generic = (emit_generic or - (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys())) + ctx = ir.Context.current + linalgDialect = ctx.get_dialect_descriptor("linalg") + fully_qualified_name = 'linalg.' + self.op_name + emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name)) op_config = op_configs[0] if op_config.structured_op: diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -7,6 +7,9 @@ from mlir.ir import * from mlir.dialects import linalg from mlir.dialects import std +# TODO: resolve name collision for Linalg functionality that is injected inside +# the _mlir.dialects.linalg directly via pybind. +from _mlir.dialects.linalg import fill_builtin_region from .scalar_expr import * from .config import * @@ -16,7 +19,6 @@ "emit_named_structured_op", ] - def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: Value): @@ -97,11 +99,18 @@ type_mapping, indexing_maps_attr, iterator_types_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) - if not op_class_name in linalg.__dict__.keys(): + # If we get here, there must exist a builtin class `op_class_name`. + ctx = Context.current + fully_qualified_name = 'linalg.' + op_name + if (not ctx.is_registered_operation(fully_qualified_name) or + not op_class_name in linalg.__dict__.keys()): raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") named_op = getattr(linalg, op_class_name)(ins, outs, out_types) + linalgDialect = ctx.get_dialect_descriptor("linalg") + fill_builtin_region(linalgDialect, named_op.operation) + if len(out_arg_defs) == 1: return named_op.result else: diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -10,5 +10,30 @@ #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, - mlir::linalg::LinalgDialect) +using namespace mlir; +using namespace mlir::linalg; + +/// Apply the special region builder for the builtin named Linalg op. +/// Assert that `op` is a builtin named Linalg op. +void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, + MlirOperation mlirOp) { + Operation *op = unwrap(mlirOp); + LinalgDialect::RegionBuilderFunType fun = + static_cast(unwrap(linalgDialect)) + ->getRegionBuilder(op->getName().getStringRef()); + assert(fun && "Expected a builtin named Linalg op."); + assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region"); + assert(op->getRegion(0).getBlocks().empty() && + "Expected Linalg op with 0 blocks"); + SmallVector argTypes; + auto linalgOp = cast(op); + for (auto t : linalgOp.getShapedOperandTypes()) + argTypes.push_back(getElementTypeOrSelf(t)); + OpBuilder b(op->getContext()); + Region ®ion = op->getRegion(0); + Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes); + // TODO: allow captures. + fun(*body, ValueRange{}); +} + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -57,6 +57,38 @@ // LinalgDialect //===----------------------------------------------------------------------===// +/// Trait to check if T provides a `regionBuilder` method. +template +using has_region_builder = decltype(T::regionBuilder); +template +using detect_has_region_builder = llvm::is_detected; + +/// SFINAE helper for single C++ class without a `regionBuilder` method (e.g. +/// an OpInterface). +template ::value>> +void addNamedOpBuilderImpl( + llvm::StringMap &map) { + // Do nothing. +} + +template ::value>, + typename = void> +void addNamedOpBuilderImpl( + llvm::StringMap &map) { + map.insert(std::make_pair( + OpType::getOperationName(), + static_cast(OpType::regionBuilder))); +} + +template +void addNamedOpBuilders( + llvm::StringMap &map) { + (void)std::initializer_list{0, + (addNamedOpBuilderImpl(map), 0)...}; +} + void mlir::linalg::LinalgDialect::initialize() { addTypes(); addOperations< @@ -72,6 +104,12 @@ #include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc" >(); + // Fill the Linalg-specific OpName to RegionBuilder map. + addNamedOpBuilders< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >(namedStructuredOpRegionBuilders); + addInterfaces(); } diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py --- a/mlir/test/Bindings/Python/dialects/linalg/ops.py +++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py @@ -5,7 +5,6 @@ from mlir.dialects import linalg from mlir.dialects import std - def run(f): print("\nTEST:", f.__name__) f() @@ -82,9 +81,9 @@ # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>) print(module) -# CHECK-LABEL: TEST: testNamedStructuredOp +# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm @run -def testNamedStructuredOp(): +def testNamedStructuredOpCustomForm(): with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() @@ -93,10 +92,45 @@ RankedTensorType.get((16, 8), f32)) def named_form(lhs, rhs): init_result = linalg.InitTensorOp([4, 8], f32) - # CHECK: linalg.matmul - # TODO: prperly hook up the region. + # First check the named form with custom format + # CHECK: linalg.matmul + # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>) + # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>) + # CHECK-SAME: -> tensor<4x8xf32> + # CHECK-NEXT: return return linalg.matmul(lhs, rhs, outs=[init_result.result]) + print(module) + +# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm +@run +def testNamedStructuredOpGenericForm(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32), + RankedTensorType.get((16, 8), f32)) + def named_form(lhs, rhs): + init_result = linalg.InitTensorOp([4, 8], f32) + # CHECK: "linalg.matmul"(%{{.*}}) + # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32): + # CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32 + # CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32 + # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () + # CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : + # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + return linalg.matmul(lhs, rhs, outs=[init_result.result]) + + module.operation.print(print_generic_op_form=True) + +# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp +@run +def testNamedStructuredAsGenericOp(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)) def generic_form(lhs, rhs):