diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -1,11 +1,11 @@ -//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect --------*- C -*-===// +//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect -------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM // Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===----------------------------------------------------------------------===// +//===---------------------------------------------------------------------===// #ifndef MLIR_C_DIALECT_LINALG_H #define MLIR_C_DIALECT_LINALG_H @@ -18,9 +18,11 @@ #endif /// Apply the special region builder for the builtin named Linalg op. +/// The list of `capture` MlirValue is passed as-is to the region builder. /// Assert that `op` is a builtin named Linalg op. 
MLIR_CAPI_EXPORTED void -mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op); +mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op, + intptr_t n, MlirValue const *mlirCaptures); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -22,10 +22,15 @@ void populateDialectLinalgSubmodule(py::module &m) { m.def( "fill_builtin_region", - [](PyDialectDescriptor &dialect, PyOperation &op) { - return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get()); + [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) { + llvm::SmallVector mlirOperands; + mlirOperands.reserve(captures.size()); + for (auto v : captures) + mlirOperands.push_back(py::cast(v)->get()); + mlirLinalgFillBuiltinNamedOpRegion( + dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data()); }, - py::arg("dialect"), py::arg("op"), + py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(), "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); } diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py @@ -5,6 +5,47 @@ from typing import Optional, Sequence, Union from ..ir import * from ._ods_common import get_default_loc_context +# TODO: resolve name collision for Linalg functionality that is injected inside +# the _mlir.dialects.linalg directly via pybind. 
+from _mlir.dialects.linalg import fill_builtin_region + + +def isa(cls : Type, ty : Type): + try: + cls(ty) + return True + except ValueError: + return False + + +class FillOp: + """Extends the linalg.fill op.""" + + def __init__(self, + output: Value, + value: Value, + *, + loc=None, + ip=None): + results = [] + if isa(RankedTensorType, output.type): + results = [output.type] + op = self.build_generic(results=results, + operands=[output, value], + attributes=None, + loc=loc, + ip=ip) + OpView.__init__(self, op) + linalgDialect = Context.current.get_dialect_descriptor("linalg") + fill_builtin_region(linalgDialect, self.operation, [value]) + # TODO: self.result is None. When len(results) == 1 we expect it to be + # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug + # in the generator of _linalg_ops_gen.py where we have: + # ``` + # def result(self): + # return self.operation.results[0] \ + # if len(self.operation.results) > 1 else None + # ``` class InitTensorOp: diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -8,6 +8,7 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" using namespace mlir; @@ -16,8 +17,14 @@ /// Apply the special region builder for the builtin named Linalg op. /// Assert that `op` is a builtin named Linalg op. 
void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, - MlirOperation mlirOp) { + MlirOperation mlirOp, intptr_t n, + MlirValue const *mlirCaptures) { Operation *op = unwrap(mlirOp); + SmallVector captures; + captures.reserve(n); + for (unsigned idx = 0; idx < n; ++idx) + captures.push_back(unwrap(mlirCaptures[idx])); + LinalgDialect::RegionBuilderFunType fun = static_cast(unwrap(linalgDialect)) ->getRegionBuilder(op->getName().getStringRef()); @@ -25,15 +32,18 @@ assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region"); assert(op->getRegion(0).getBlocks().empty() && "Expected Linalg op with 0 blocks"); + SmallVector argTypes; auto linalgOp = cast(op); for (auto t : linalgOp.getShapedOperandTypes()) argTypes.push_back(getElementTypeOrSelf(t)); + OpBuilder b(op->getContext()); Region &region = op->getRegion(0); Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes); - // TODO: allow captures. - fun(*body, ValueRange{}); + b.setInsertionPointToStart(body); + mlir::edsc::ScopedContext scope(b, op->getLoc()); + fun(*body, captures); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py --- a/mlir/test/Bindings/Python/dialects/linalg/ops.py +++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py @@ -38,6 +38,40 @@ print(module) +# CHECK-LABEL: TEST: testFill +@run +def testFill(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + # CHECK-LABEL: func @fill_tensor + # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32> + # CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32 + # CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[OUT]], %[[CST]]) : tensor<12x?xf32>, f32 -> tensor<12x?xf32> + # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32> + @builtin.FuncOp.from_py_func( + RankedTensorType.get((12, -1), f32)) + def fill_tensor(out): + zero = 
std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result + # TODO: FillOp.result is None. When len(results) == 1 we expect it to + # be results[0] as per _linalg_ops_gen.py. This seems like an + # orthogonal bug in the generator of _linalg_ops_gen.py. + return linalg.FillOp(output=out, value=zero).results[0] + + # CHECK-LABEL: func @fill_buffer + # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32> + # CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32 + # CHECK-NEXT: linalg.fill(%[[OUT]], %[[CST]]) : memref<12x?xf32>, f32 + # CHECK-NEXT: return + @builtin.FuncOp.from_py_func( + MemRefType.get((12, -1), f32)) + def fill_buffer(out): + zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result + linalg.FillOp(output=out, value=zero) + + print(module) + # CHECK-LABEL: TEST: testStructuredOpOnTensors @run