diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -18,9 +18,9 @@ #endif /// Apply the special region builder for the builtin named Linalg op. -/// Assert that `op` is a builtin named Linalg op. +/// Assert that `mlirOp` is a builtin named Linalg op. MLIR_CAPI_EXPORTED void -mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op); +mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -7,24 +7,17 @@ //===----------------------------------------------------------------------===// #include "Dialects.h" -#include "IRModule.h" #include "mlir-c/Dialect/Linalg.h" #include "mlir-c/IR.h" - -// TODO: Port this to operate only on the public PybindAdaptors.h -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" namespace py = pybind11; -using namespace mlir; -using namespace mlir::python; void mlir::python::populateDialectLinalgSubmodule(py::module m) { m.def( "fill_builtin_region", - [](PyDialectDescriptor &dialect, PyOperation &op) { - mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get()); - }, - py::arg("dialect"), py::arg("op"), + [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); }, + py::arg("op"), "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); } diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -15,20 +15,19 @@ /// Apply the special region builder for the builtin named Linalg op. -/// Assert that `op` is a builtin named Linalg op. +/// Assert that `mlirOp` is a builtin named Linalg op.
-void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, - MlirOperation mlirOp) { +void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { Operation *op = unwrap(mlirOp); - + auto linalgOp = cast<LinalgOp>(op); + auto *dialect = static_cast<LinalgDialect *>(linalgOp->getDialect()); LinalgDialect::RegionBuilderFunType fun = - static_cast<LinalgDialect *>(unwrap(linalgDialect)) - ->getRegionBuilder(op->getName().getStringRef()); + dialect->getRegionBuilder(op->getName().getStringRef()); + assert(fun && "Expected a builtin named Linalg op."); assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region"); assert(op->getRegion(0).getBlocks().empty() && "Expected Linalg op with 0 blocks"); SmallVector<Type, 8> argTypes; - auto linalgOp = cast<LinalgOp>(op); for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType())); diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -34,8 +34,7 @@ loc=loc, ip=ip) OpView.__init__(self, op) - linalgDialect = Context.current.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, self.operation) + fill_builtin_region(self.operation) class InitTensorOp: """Extends the linalg.init_tensor op.""" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -173,8 +173,7 @@ f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - linalgDialect = ctx.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, named_op.operation) + fill_builtin_region(named_op.operation) # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps # 
attribute that the non-yaml path does not. The non-yaml path hardcodes the # indexing_maps in C++ directly.