diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -43,6 +43,9 @@
 /// Verify that `op` conforms to the ConvolutionOpInterface.
 LogicalResult verifyConvolutionInterface(Operation *op);
 
+/// Verify that `op` conforms to the FillOpInterface.
+LogicalResult verifyFillInterface(Operation *op);
+
 /// Verify that `op` conforms to the invariants of StructuredOpInterface
 LogicalResult verifyStructuredOpInterface(Operation *op);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -132,6 +132,50 @@
   ];
 }
 
+def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
+  let description = [{
+    A fill operation is defined in general terms:
+    1. Has a scalar `value` operand.
+    2. Has one `output` operand.
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyFillInterface($_op); }];
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/"Return the fill value.",
+      /*retTy=*/"Value",
+      /*methodName=*/"value",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getOperation()->getOperand(0);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the output operand.",
+      /*retTy=*/"Value",
+      /*methodName=*/"output",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getOperation()->getOperand(1);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the result.",
+      /*retTy=*/"Value",
+      /*methodName=*/"result",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if ($_op.getOperation()->getResults().empty())
+          return nullptr;
+        return $_op.getOperation()->getResults().front();
+      }]
+    >,
+  ];
+}
+
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
 def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2875,6 +2875,8 @@
     Works for arbitrary ranked output tensors since the operation performs scalar
     accesses only and is thus rank polymorphic. Numeric casting is performed on
     the value operand, promoting it to the same data type as the output.
+  implements:
+  - LinalgFillOpInterface
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -408,6 +408,47 @@
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+
+/// Outcome of matching an operation against the FillOpInterface requirements.
+enum class MatchFillResult {
+  Success = 0,
+  NotLinalgOp,
+  WrongNumOperands,
+  NotScalarInput
+};
+
+/// Returns `Success` when `op` is a LinalgOp with exactly one input and one
+/// output whose input operand is a scalar; otherwise the first failed check.
+static MatchFillResult isFillInterfaceImpl(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return MatchFillResult::NotLinalgOp;
+  if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1)
+    return MatchFillResult::WrongNumOperands;
+
+  OpOperand *value = linalgOp.getInputOperand(0);
+  if (!linalgOp.isScalar(value))
+    return MatchFillResult::NotScalarInput;
+
+  return MatchFillResult::Success;
+}
+
+LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
+  auto res = isFillInterfaceImpl(op);
+  if (res == MatchFillResult::NotLinalgOp)
+    return op->emitError("expected a LinalgOp");
+  if (res == MatchFillResult::WrongNumOperands)
+    return op->emitError("expected op with 1 input and 1 output");
+  if (res == MatchFillResult::NotScalarInput)
+    return op->emitError("expected op with scalar input");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // StructuredOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -686,6 +686,7 @@
 
 ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
 ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
+FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
 
 
 class OpMetadataDef(YAMLObject):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -671,6 +671,7 @@
   accesses only and is thus rank polymorphic. Numeric casting is performed on
   the value operand, promoting it to the same data type as the output.
   """
+  implements(FillOpInterface)
   O[None] = TypeFn.cast_signed(U, value)
 
 
diff --git a/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+func @test_fill_op_not_linalg_op(%arg0 : f32, %arg1 : tensor<?xf32>)
+    -> tensor<?xf32> {
+  // expected-error @+1 {{expected a LinalgOp}}
+  %0 = "test.fill_op_not_linalg_op"(%arg0, %arg1)
+      : (f32, tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> ()>
+#map1 = affine_map<(d0) -> (d0)>
+func @test_fill_op_wrong_num_operands(%arg0 : f32, %arg1 : tensor<?xf32>)
+    -> tensor<?xf32> {
+  // expected-error @+1 {{expected op with 1 input and 1 output}}
+  %0 = test.linalg_fill_op {
+      indexing_maps = [#map0, #map0, #map1],
+      iterator_types = ["parallel"]}
+      ins(%arg0, %arg0 : f32, f32) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+      linalg.yield %arg2 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+#map1 = affine_map<(d0) -> (d0)>
+func @test_fill_op_non_scalar_input(%arg0 : tensor<?xf32>,
+    %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error @+1 {{expected op with scalar input}}
+  %0 = test.linalg_fill_op {
+      indexing_maps = [#map1, #map1],
+      iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%arg2 : f32, %arg3 : f32):
+      linalg.yield %arg2 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2640,6 +2640,64 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test LinalgFillOpInterface.
+//===----------------------------------------------------------------------===//
+
+def TestLinalgFillOpNotLinalgOp : TEST_Op<"fill_op_not_linalg_op", [
+    LinalgFillOpInterface]> {
+  let arguments = (ins
+    AnyType:$value, AnyType:$output);
+  let results = (outs AnyRankedTensor:$result);
+}
+
+def TestLinalgFillOp :
+    TEST_Op<"linalg_fill_op", [AttrSizedOperandSegments, SingleBlock,
+      LinalgStructuredInterface, LinalgFillOpInterface]> {
+
+  let arguments = (ins Variadic<AnyType>:$inputs,
+                       Variadic<AnyType>:$outputs);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+    `outs` `(` $outputs `:` type($outputs) `)`
+    $region (`->` type($results)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    bool hasIndexSemantics() { return false; }
+
+    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+                              mlir::ArrayRef<mlir::NamedAttribute> attrs) {
+      b.create<mlir::linalg::YieldOp>(block.getArguments().back());
+    }
+
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
+    getRegionBuilder() {
+      return &regionBuilder;
+    }
+
+    mlir::ArrayAttr iterator_types() {
+      return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+    }
+
+    mlir::ArrayAttr indexing_maps() {
+      return getOperation()->getAttrOfType<mlir::ArrayAttr>("indexing_maps");
+    }
+
+    std::string getLibraryCallName() {
+      return "";
+    }
+
+    // To conform with interface requirement on operand naming.
+    mlir::ValueRange inputs() { return getInputs(); }
+    mlir::ValueRange outputs() { return getOutputs(); }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Ops with Default-Valued String Attributes
 //===----------------------------------------------------------------------===//