diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -148,8 +148,9 @@
 }

 def FillOp : LinalgStructured_Op<"fill", []> {
-  let arguments = (ins AnyStridedMemRef:$output,
+  let arguments = (ins AnyShaped:$output,
                    AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
+  let results = (outs Optional<AnyRankedTensor>:$result);
   let extraClassDeclaration = libraryCallName # [{
     ValueRange inputs() { return {}; }
     ValueRange outputs() { return getOperands().take_front(); }
@@ -174,6 +175,14 @@
     }
   }];

+  let assemblyFormat = [{
+    `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
+  }];
+
+  let builders = [
+    OpBuilderDAG<(ins "Value":$output, "Value":$value)>
+  ];
+
   let verifier = [{ return ::verify(*this); }];

   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -220,6 +220,16 @@
 // LinalgOps.td), we define an overloaded `print` function and a
 // parse`className` function.

+//===----------------------------------------------------------------------===//
+// FillOp
+//===----------------------------------------------------------------------===//
+
+void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
+                   Value value) {
+  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
+        value);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericOps
 //===----------------------------------------------------------------------===//
@@ -1726,6 +1736,10 @@
   auto fillType = op.value().getType();
   if (viewType.getElementType() != fillType)
     return op.emitOpError("expects fill type to match view elemental type");
+  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+    return op.emitOpError(
+        "expected fill op with no result value to use memref type");
+  }
   return success();
 }

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -659,3 +659,41 @@
   } : tensor to tensor
   return %0 : tensor
 }
+
+// -----
+
+func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
+{
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  // expected-error @+1 {{expected fill op with no result value to use memref type}}
+  linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32
+}
+
+// -----
+
+func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+{
+  // expected-error @+1 {{unexpected #results > #outputs}}
+  %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @illegal_fill_memref_with_tensor_return
+  (%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
+{
+  // expected-error @+1 {{unexpected #results > #outputs}}
+  %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @illegal_fill_tensor_with_memref_return
+  (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+{
+  // expected-error @+1 {{expected type of operand #0 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
+  %0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -805,3 +805,12 @@
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
 // CHECK: func @legal_collapsing_reshape_dynamic_memref
 // CHECK:   linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+
+// -----
+
+func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -41,7 +41,7 @@
   %4 = linalg.generic
        {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
-                         affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
+                         affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
        iterator_types = ["parallel", "parallel", "parallel"]}
       ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
      outs(%3 : tensor<?x?x?xf32>) {
@@ -88,7 +88,7 @@
   %4 = linalg.indexed_generic
        {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
-                         affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
+                         affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
        iterator_types = ["parallel", "parallel", "parallel"]}
       ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
      outs(%3 : tensor<?x?x?xf32>) {
@@ -120,3 +120,26 @@
 // CHECK:       scf.yield %[[TD1]]
 // CHECK:     }
 // CHECK:   return %[[TD0]]
+
+// -----
+
+func @fill_tensors(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK: func @fill_tensors
+// CHECK:   %[[INIT:.+]] = linalg.init_tensor
+// CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-z0-9_]+]]
+// CHECK-SAME:   iter_args(%[[ARG4:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK:     %[[YIELD_1:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME:   iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) {
+// CHECK:       %[[FILL_TILE:.+]] = subtensor %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK:       %[[RESULT_TILE:.+]] = linalg.fill(%[[FILL_TILE]], %{{.+}})
+// CHECK:       %[[YIELD_2:.+]] = subtensor_insert %[[RESULT_TILE]]
+// CHECK-SAME:     into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK:       scf.yield %[[YIELD_2]]
+// CHECK:     }
+// CHECK:     scf.yield %[[YIELD_1]]
+// CHECK:   }
+// CHECK:   return %[[RESULT]]
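
For reference, the two forms of linalg.fill this patch allows, summarized from the tests above (a minimal sketch; %buf, %cst, %d0, and %d1 are illustrative names, not taken from the patch). On a memref the op fills the buffer in place and produces no result; on a tensor it returns the filled tensor as an SSA value.

  // Buffer form: fills %buf in place, no result value.
  linalg.fill(%buf, %cst) : memref<?x?xf32>, f32

  // Tensor form: produces a new filled tensor as a result.
  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
  %fill = linalg.fill(%init, %cst) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>

The verifier change above enforces this pairing: a fill with no result must operate on a memref, and a fill on a tensor must return the tensor result.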