diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -43,6 +43,9 @@
 /// Verify that `op` conforms to the ConvolutionOpInterface.
 LogicalResult verifyConvolutionInterface(Operation *op);
 
+/// Verify that `op` conforms to the FillOpInterface.
+LogicalResult verifyFillInterface(Operation *op);
+
 /// Verify that `op` conforms to the invariants of StructuredOpInterface
 LogicalResult verifyStructuredOpInterface(Operation *op);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -132,6 +132,50 @@
   ];
 }
 
+def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
+  let description = [{
+    A fill operation is defined in general terms:
+    1. Has exactly one input: a scalar `value` operand (operand 0).
+    2. Has exactly one `output` operand (operand 1).
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyFillInterface($_op); }];
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/"Return the fill value.",
+      /*retTy=*/"Value",
+      /*methodName=*/"value",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getOperation()->getOperand(0);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the output operand.",
+      /*retTy=*/"Value",
+      /*methodName=*/"output",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getOperation()->getOperand(1);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the result, or a null value if the op has no results.",
+      /*retTy=*/"Value",
+      /*methodName=*/"result",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if ($_op.getOperation()->getResults().empty())
+          return nullptr;
+        return $_op.getOperation()->getResults().front();
+      }]
+    >,
+  ];
+}
+
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
 def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let cppNamespace = "::mlir::linalg";
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -408,6 +408,47 @@
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+
+/// Outcome of matching an op against the FillOpInterface invariants.
+enum class MatchFillResult {
+  Success = 0,
+  NotLinalgOp,
+  WrongNumOperands,
+  NotScalarInput
+};
+
+/// Classify `op` against the FillOpInterface invariants: a LinalgOp with
+/// exactly one input and one output, where that input is a scalar.
+static MatchFillResult isFillInterfaceImpl(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return MatchFillResult::NotLinalgOp;
+  if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1)
+    return MatchFillResult::WrongNumOperands;
+
+  OpOperand *value = linalgOp.getInputOperand(0);
+  if (!linalgOp.isScalar(value))
+    return MatchFillResult::NotScalarInput;
+
+  return MatchFillResult::Success;
+}
+
+LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
+  auto res = isFillInterfaceImpl(op);
+  if (res == MatchFillResult::NotLinalgOp)
+    return op->emitError("expected a LinalgOp");
+  if (res == MatchFillResult::WrongNumOperands)
+    return op->emitError("expected op with 1 input and 1 output");
+  if (res == MatchFillResult::NotScalarInput)
+    return op->emitError("expected op with scalar input");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // StructuredOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -686,6 +686,7 @@
 
 ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
 ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
+FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
 
 
 class OpMetadataDef(YAMLObject):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -671,6 +671,7 @@
   accesses only and is thus rank polymorphic. Numeric casting is performed on
   the value operand, promoting it to the same data type as the output.
   """
+  implements(FillOpInterface)
   O[None] = TypeFn.cast_signed(U, value)
 
 
diff --git a/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+func @test_fill_op_not_linalg_op(%arg0 : f32, %arg1 : tensor<?xf32>)
+    -> tensor<?xf32> {
+  // expected-error @+1 {{expected a LinalgOp}}
+  %0 = "test.fill_op_not_linalg_op"(%arg0, %arg1)
+      : (f32, tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> ()>
+#map1 = affine_map<(d0) -> (d0)>
+func @test_fill_op_wrong_num_operands(%arg0 : f32, %arg1 : tensor<?xf32>)
+    -> tensor<?xf32> {
+  // expected-error @+1 {{expected op with 1 input and 1 output}}
+  %0 = test.linalg_fill_op {
+      indexing_maps = [#map0, #map0, #map1],
+      iterator_types = ["parallel"]}
+      ins(%arg0, %arg0 : f32, f32) outs(%arg1 : tensor<?xf32>) {
+      ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+        linalg.yield %arg2 : f32
+      } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+#map1 = affine_map<(d0) -> (d0)>
+func @test_fill_op_non_scalar_input(%arg0 : tensor<?xf32>,
+    %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error @+1 {{expected op with scalar input}}
+  %0 = test.linalg_fill_op {
+      indexing_maps = [#map1, #map1],
+      iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
+      ^bb0(%arg2 : f32, %arg3 : f32):
+        linalg.yield %arg2 : f32
+      } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2640,6 +2640,67 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test LinalgFillOpInterface.
+//===----------------------------------------------------------------------===//
+
+// Implements FillOpInterface without LinalgStructuredInterface; used to test
+// the "expected a LinalgOp" verifier error.
+def TestLinalgFillOpNotLinalgOp : TEST_Op<"fill_op_not_linalg_op", [
+    LinalgFillOpInterface]> {
+  let arguments = (ins
+    AnyType:$value, AnyType:$output);
+  let results = (outs AnyRankedTensor:$result);
+}
+
+def TestLinalgFillOp :
+  TEST_Op<"linalg_fill_op", [AttrSizedOperandSegments, SingleBlock,
+      LinalgStructuredInterface, LinalgFillOpInterface]> {
+
+  let arguments = (ins Variadic<AnyType>:$inputs,
+    Variadic<AnyType>:$outputs);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+    `outs` `(` $outputs `:` type($outputs) `)`
+    $region (`->` type($results)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    bool hasIndexSemantics() { return false; }
+
+    // A fill yields its scalar `value`, i.e. the first block argument.
+    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+                              mlir::ArrayRef<mlir::NamedAttribute> attrs) {
+      b.create<mlir::linalg::YieldOp>(block.getArguments().front());
+    }
+
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
+    getRegionBuilder() {
+      return &regionBuilder;
+    }
+
+    mlir::ArrayAttr iterator_types() {
+      return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+    }
+
+    mlir::ArrayAttr indexing_maps() {
+      return getOperation()->getAttrOfType<mlir::ArrayAttr>("indexing_maps");
+    }
+
+    std::string getLibraryCallName() {
+      return "";
+    }
+
+    // To conform with interface requirement on operand naming.
+    mlir::ValueRange inputs() { return getInputs(); }
+    mlir::ValueRange outputs() { return getOutputs(); }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Ops with Default-Valued String Attributes
 //===----------------------------------------------------------------------===//