diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -148,8 +148,9 @@
 }
 
 def FillOp : LinalgStructured_Op<"fill", []> {
-  let arguments = (ins AnyStridedMemRef:$output,
+  let arguments = (ins AnyShaped:$output,
                    AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
+  let results = (outs Optional<AnyRankedTensor>:$result);
   let extraClassDeclaration = libraryCallName # [{
     ValueRange inputs() { return {}; }
     ValueRange outputs() { return getOperands().take_front(); }
@@ -174,6 +175,13 @@
     }
   }];
 
+  let assemblyFormat = [{
+    `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?}];
+
+  let builders = [
+    OpBuilderDAG<(ins "Value":$output, "Value":$value)>
+  ];
+
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -220,6 +220,16 @@
 // LinalgOps.td), we define an overloaded `print` function and a
 // parse`className` function.
 
+//===----------------------------------------------------------------------===//
+// FillOp
+//===----------------------------------------------------------------------===//
+
+void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
+                   Value value) {
+  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
+        value);
+}
+
+//===----------------------------------------------------------------------===//
 // GenericOps
 //===----------------------------------------------------------------------===//
@@ -1574,6 +1584,10 @@
   auto fillType = op.value().getType();
   if (viewType.getElementType() != fillType)
     return op.emitOpError("expects fill type to match view elemental type");
+  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+    return op.emitOpError(
+        "expected fill op with no result value to use memref type");
+  }
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -617,3 +617,41 @@
        memref into memref
   return %0 : memref
 }
+
+// -----
+
+func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
+{
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  // expected-error @+1 {{expected fill op with no result value to use memref type}}
+  linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32
+}
+
+// -----
+
+func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+{
+  // expected-error @+1 {{unexpected #results > #outputs}}
+  %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @illegal_fill_memref_with_tensor_return
+  (%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
+{
+  // expected-error @+1 {{unexpected #results > #outputs}}
+  %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @illegal_fill_tensor_with_memref_return
+  (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+{
+  // expected-error @+1 {{expected type of operand #0 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
+  %0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -753,3 +753,12 @@
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
 // CHECK: func @legal_collapsing_reshape_dynamic_memref
 // CHECK:   linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+
+// -----
+
+func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>