diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -2908,8 +2908,8 @@ scalar_arg: I --- !LinalgOpConfig metadata: !LinalgOpMetadata - name: fill_tensor - cpp_class_name: FillTensorOp + name: fill + cpp_class_name: FillOp doc: |- Fills the output tensor with the given value. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -43,71 +43,6 @@ }]; } -class LinalgStructured_Op props> - : LinalgStructuredBase_Op { - code structuredOpsDecls = structuredOpsBaseDecls # [{ - std::string getLibraryCallName() { - return generateLibraryCallName(getOperation()); - } - }]; - let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; -} - -//===----------------------------------------------------------------------===// -// Named Linalg ops, implemented as special configurations of generic ops. -//===----------------------------------------------------------------------===// - -def FillOp : LinalgStructured_Op<"fill", []> { - let arguments = (ins - AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger, AnyVector]>:$value, - AnyShaped:$output); - let results = (outs Optional:$result); - let regions = (region AnyRegion:$region); - let extraClassDeclaration = structuredOpsDecls # [{ - ValueRange inputs() { return getOperands().take_front(); } - ValueRange outputs() { return getOperands().take_back(); } - - // Rank-polymorphic. - // filling_value -> O(ivs) with parallel iterators. - ArrayAttr iterator_types() { - int64_t nPar = getRank(getOutputOperand(0)); - return Builder(getContext()).getStrArrayAttr( - SmallVector(nPar, getParallelIteratorTypeName())); - } - - ArrayAttr indexing_maps() { - MLIRContext *context = getContext(); - // filling_value -> O(ivs) - return Builder(getContext()).getAffineMapArrayAttr({ - AffineMap::get(getNumParallelLoops(), 0, {}, getContext()), - extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); - } - - static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs); - static std::function)> - getRegionBuilder() { - return ®ionBuilder; - } - static unsigned getNumRegionArgs() { return 2; } - }]; - - let assemblyFormat = [{ - `(` $value `,` $output `)` attr-dict `:` - type($value) `,` type($output) (`->` type($result)^)? - custom($region, ref(type($value)), ref(type($output))) - }]; - - let builders = [ - OpBuilder<(ins "Value":$value, "Value":$output)> - ]; - - let hasFolder = 1; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - //===----------------------------------------------------------------------===// // Generic Linalg ops. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -856,8 +856,10 @@ op, "No initial value found for reduction operation"); auto fillValue = rewriter.create(loc, fillValueAttr); - auto filledTensor = - rewriter.create(loc, fillValue, initTensor).result(); + auto filledTensor = rewriter + .create(loc, ValueRange{fillValue}, + ValueRange{initTensor}) + .result(); SmallVector srcExprs; SmallVector dstExprs; @@ -1717,7 +1719,9 @@ Value zeroVal = rewriter.createOrFold( loc, rewriter.getZeroAttr(resultType.getElementType())); Value result = - rewriter.create(loc, zeroVal, init).getResult(0); + rewriter + .create(loc, ValueRange{zeroVal}, ValueRange{init}) + .result(); auto toOpFoldResult = [](Value v) -> OpFoldResult { auto op = v.getDefiningOp(); @@ -1989,7 +1993,9 @@ auto fillValueIdx = rewriter.create( loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = - rewriter.create(loc, fillValueIdx, initTensorIdx) + rewriter + .create(loc, ValueRange{fillValueIdx}, + ValueRange{initTensorIdx}) .result(); // Second fill the output buffer for the running max. @@ -2007,7 +2013,9 @@ auto fillValueMax = rewriter.create(loc, fillValueMaxAttr); auto filledTensorMax = - rewriter.create(loc, fillValueMax, initTensorMax) + rewriter + .create(loc, ValueRange{fillValueMax}, + ValueRange{initTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -214,8 +214,10 @@ Value initTensor = rewriter.create( loc, filteredDims, resultTy.getShape(), resultETy); Value zero = rewriter.create(loc, resultZeroAttr); - Value zeroTensor = - rewriter.create(loc, zero, initTensor).getResult(0); + Value zeroTensor = rewriter + .create(loc, ValueRange{zero}, + ValueRange{initTensor}) + .result(); // Extract the attributes for convolution. llvm::SmallVector stride, dilation; @@ -401,8 +403,10 @@ Value initTensor = rewriter.create( loc, dynamicDims, linalgConvTy.getShape(), resultETy); Value zero = rewriter.create(loc, resultZeroAttr); - Value zeroTensor = - rewriter.create(loc, zero, initTensor).getResult(0); + Value zeroTensor = rewriter + .create(loc, ValueRange{zero}, + ValueRange{initTensor}) + .result(); Value biasInitTensor = rewriter.create( loc, dynamicDims, resultTy.getShape(), resultETy); @@ -493,8 +497,10 @@ Value zero = rewriter.create(loc, zeroAttr); auto initTensor = rewriter.create( loc, filteredDims, outputTy.getShape(), outputTy.getElementType()); - Value zeroTensor = - rewriter.create(loc, zero, initTensor).getResult(0); + Value zeroTensor = rewriter + .create(loc, ValueRange{zero}, + ValueRange{initTensor}) + .result(); if (!op.quantization_info()) { rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()}, @@ -567,8 +573,10 @@ // When quantized, the input elemeny type is not the same as the output Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy); Value zero = rewriter.create(loc, resultZeroAttr); - Value zeroTensor = - rewriter.create(loc, zero, initTensor).getResult(0); + Value zeroTensor = rewriter + .create(loc, ValueRange{zero}, + ValueRange{initTensor}) + .result(); SmallVector permutation{1, 0}; auto permutationAttr = DenseIntElementsAttr::get( @@ -700,7 +708,10 @@ loc, dynamicDims, resultTy.getShape(), resultTy.getElementType()); Value filledInitTensor = - rewriter.create(loc, initialValue, initTensor).result(); + rewriter + .create(loc, ValueRange{initialValue}, + ValueRange{initTensor}) + .result(); Value fakeWindowDims = rewriter.create(loc, kernel, resultETy); @@ -759,7 +770,9 @@ loc, dynamicDims, accTy.getShape(), accETy); Value filledInitTensor = - rewriter.create(loc, initialValue, poolInitTensor) + rewriter + .create(loc, ValueRange{initialValue}, + ValueRange{poolInitTensor}) .result(); Value fakeWindowDims = diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -303,47 +303,6 @@ //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// -void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { - assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args"); - b.create(block.getArgument(0)); -} - -void FillOp::build(OpBuilder &builder, OperationState &result, Value value, - Value output) { - build(builder, result, output.getType().dyn_cast(), value, - output); - fillStructuredOpRegion( - builder, *result.regions.front(), TypeRange{value.getType()}, - TypeRange{output.getType()}, result.attributes.getAttrs(), {}); -} - -ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType, - Type outputType) { - OpBuilder opBuilder(parser.getContext()); - fillStructuredOpRegion(opBuilder, r, TypeRange{valueType}, - TypeRange{outputType}, {}); - return success(); -} - -/// FillOp region is elided when printing. -void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} - -LogicalResult FillOp::verify() { - OpOperand *output = getOutputOperand(0); - Type fillType = value().getType(); - if (getElementTypeOrSelf(output->get()) != fillType) - return emitOpError("expects fill type to match view elemental type"); - return success(); -} - -void FillOp::getEffects( - SmallVectorImpl> - &effects) { - if (output().getType().isa()) - effects.emplace_back(MemoryEffects::Write::get(), output(), - SideEffects::DefaultResource::get()); -} namespace { @@ -364,7 +323,8 @@ auto newInit = rewriter.create( loc, reshapeOp.getResultType(), oldFill.output(), reshapeOp.reassociation()); - rewriter.replaceOpWithNewOp(reshapeOp, oldFill.value(), newInit); + rewriter.replaceOpWithNewOp(reshapeOp, ValueRange{oldFill.value()}, + ValueRange{newInit}); return success(); } @@ -400,8 +360,8 @@ auto newInitOp = rewriter.create( padOp.getLoc(), reifiedShape.front(), staticShape, oldResultType.getElementType()); - auto newFillOp = - rewriter.create(fillOp.getLoc(), padValue, newInitOp); + auto newFillOp = rewriter.create( + fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp}); rewriter.replaceOpWithNewOp(padOp, oldResultType, newFillOp.result()); @@ -517,10 +477,6 @@ FoldInsertPadIntoFill>(context); } -// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp. -void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) {} - //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// @@ -877,6 +833,11 @@ results.add(context); } +LogicalResult GenericOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // InitTensorOp //===----------------------------------------------------------------------===// @@ -1812,15 +1773,6 @@ } // namespace -#define LINALGOP_FOLDERS(XXX) \ - LogicalResult XXX::fold(ArrayRef, \ - SmallVectorImpl &) { \ - return foldMemRefCast(*this); \ - } - -LINALGOP_FOLDERS(FillOp) -LINALGOP_FOLDERS(GenericOp) - // All named ops canonicalizers and folders are auto-generated in the // .cpp.inc. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -783,8 +783,10 @@ loc, resultShapedType.getShape(), resultShapedType.getElementType()); // Initialize tensor with the pad value - Value tmpTensor = - rewriter.create(loc, padValue, initTensor).result(); + Value tmpTensor = rewriter + .create(loc, ValueRange{padValue}, + ValueRange{initTensor}) + .result(); // Copy original contents into new tensor // Uses linalg.generic, but could be done with tensor.insert_slice diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -334,7 +334,7 @@ } Value mem = rewriter.create(loc, memTp, dynamicSizes); Value zero = constantZero(rewriter, loc, elemTp); - rewriter.create(loc, zero, mem); + rewriter.create(loc, ValueRange{zero}, ValueRange{mem}); return mem; } @@ -749,10 +749,12 @@ // introduces an O(N) operation into the computation, but this reset // operation is amortized over the innermost loops for the access // pattern expansion. - rewriter.create(loc, constantZero(rewriter, loc, eltType), - values); - rewriter.create(loc, constantZero(rewriter, loc, boolType), - filled); + rewriter.create( + loc, ValueRange{constantZero(rewriter, loc, eltType)}, + ValueRange{values}); + rewriter.create( + loc, ValueRange{constantZero(rewriter, loc, boolType)}, + ValueRange{filled}); // Replace expansion op with these buffers and initial index. assert(op.getNumResults() == 4); rewriter.replaceOp(op, {values, filled, indices, zero}); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -466,7 +466,7 @@ Value alloc = rewriter.create(loc, denseTp, args); if (isMaterializing(tensor)) { Value zero = constantZero(rewriter, loc, denseTp.getElementType()); - rewriter.create(loc, zero, alloc); + rewriter.create(loc, ValueRange{zero}, ValueRange{alloc}); } else { Value init = rewriter.create(loc, denseTp, tensor); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -262,7 +262,8 @@ b.create(loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { - b.create(loc, xferOp.padding(), alloc); + b.create(loc, ValueRange{xferOp.padding()}, + ValueRange{alloc}); // Take partial subview of memref which guarantees no dimension // overflows. IRRewriter rewriter(b); diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -20,22 +20,6 @@ return False -class FillOp: - """Extends the linalg.fill op.""" - - def __init__(self, output: Value, value: Value, *, loc=None, ip=None): - results = [] - if isa(RankedTensorType, output.type): - results = [output.type] - op = self.build_generic( - results=results, - operands=[_get_op_result_or_value(o) for o in [value, output]], - attributes=None, - loc=loc, - ip=ip) - OpView.__init__(self, op) - fill_builtin_region(self.operation) - class InitTensorOp: """Extends the linalg.init_tensor op.""" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -715,7 +715,7 @@ @linalg_structured_op -def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): +def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)): """Fills the output tensor with the given value. Works for arbitrary ranked output tensors since the operation performs scalar diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -4,7 +4,7 @@ func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) { // CHECK: [[C0:%.+]] = arith.constant 0 // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6] - // CHECK: [[FILLED:%.+]] = linalg.fill([[C0]], [[INIT]]) : f32, tensor<1x5x6xf32> -> tensor<1x5x6xf32> + // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) return %0 : tensor<1x5x6xf32> @@ -17,7 +17,7 @@ func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) { // CHECK: [[C0:%.+]] = arith.constant 0 // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6] - // CHECK: [[FILLED:%.+]] = linalg.fill([[C0]], [[INIT]]) : i32, tensor<1x5x6xi32> -> tensor<1x5x6xi32> + // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : i32) outs([[INIT]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32> // CHECK: [[ONE:%.+]] = arith.constant 1 // CHECK: [[TWO:%.+]] = arith.constant 2 // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32> @@ -33,7 +33,7 @@ // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[C0_0:.+]] = arith.constant 0 // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]], 5, 6] - // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0_0]], %[[INIT]]) : f32, tensor -> tensor + // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor) -> tensor // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor, tensor) outs(%[[FILLED]] : tensor) -> tensor %0 = "tosa.matmul"(%arg0, %arg1) : (tensor, tensor) -> (tensor) return %0 : tensor @@ -47,7 +47,7 @@ // CHECK: %[[DIM:.+]] = tensor.dim %arg1, %[[C2]] // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, %[[DIM]]] - // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x5x?xf32> -> tensor<1x5x?xf32> + // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32> %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> (tensor<1x5x?xf32>) return %0 : tensor<1x5x?xf32> @@ -59,7 +59,7 @@ func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>) { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, 6] - // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x5x6xf32> -> tensor<1x5x6xf32> + // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>) return %0 : tensor<1x5x6xf32> @@ -74,7 +74,7 @@ func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) { // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6] // CHECK: [[ZERO:%.+]] = arith.constant 0 - // CHECK: [[FILL:%.+]] = linalg.fill([[ZERO]], [[INITT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]] // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]> // CHECK: [[TRANSPOSE:%.+]] = "tosa.transpose"(%arg1, [[PERM]]) // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6] @@ -97,7 +97,7 @@ func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) { // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6] // CHECK: [[ZERO:%.+]] = arith.constant 0 - // CHECK: [[FILL:%.+]] = linalg.fill([[ZERO]], [[INITT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]] // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]> // CHECK: [[TRANSPOSE:%.+]] = "tosa.transpose"(%arg1, [[PERM]]) // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6] @@ -123,7 +123,7 @@ // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[INITT:.+]] = linalg.init_tensor [%[[DIM]], 6] // CHECK: %[[ZERO:.+]] = arith.constant 0 - // CHECK: %[[FILL:.+]] = linalg.fill(%[[ZERO]], %[[INITT]]) + // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]] // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> // CHECK: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERM]]) // CHECK: %[[INITB:.+]] = linalg.init_tensor [%[[DIM]], 6] @@ -143,7 +143,7 @@ func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () { // CHECK-DAG: [[CONST:%.+]] = arith.constant -3.40282347E+38 // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 32, 62] - // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[CONST]], [[INIT]]) + // CHECK-DAG: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[INIT]] // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3] // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x32x62xf32>) %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) @@ -157,7 +157,7 @@ // CHECK-DAG: tensor.yield [[CONST]] // CHECK-DAG: [[INITVAL:%.+]] = arith.constant -3.40282347E+38 : f32 // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 33, 62] - // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INITVAL]], [[INIT]]) + // CHECK-DAG: [[FILL:%.+]] = linalg.fill ins([[INITVAL]]{{.*}}outs([[INIT]] // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3] // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x6x35x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x33x62xf32>) %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 1], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x33x62xf32>) @@ -170,7 +170,7 @@ // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[CONST:.+]] = arith.constant -3.40282347E+38 // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 32, 62] - // CHECK: %[[FILL:.+]] = linalg.fill(%[[CONST]], %[[INIT]]) + // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CONST]]{{.*}}outs(%[[INIT]] // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [3, 3] // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %[[KERNEL]] : tensor, tensor<3x3xf32>) outs(%[[FILL]] : tensor) %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor) -> (tensor) @@ -209,7 +209,7 @@ // CHECK: [[PAD:%.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: [[CONST:%.+]] = arith.constant 0 // CHECK: [[POOLINIT:%.+]] = linalg.init_tensor [1, 5, 33, 62] - // CHECK: [[FILL:%.+]] = linalg.fill([[CONST]], [[POOLINIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[POOLINIT]] // CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4] // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>) // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62] @@ -474,7 +474,7 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11] // CHECK: [[CST0:%.+]] = arith.constant 0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33] // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>) // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]} @@ -520,7 +520,7 @@ func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () { // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11] // CHECK: [[CST0:%.+]] = arith.constant 0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33] // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>) // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]} @@ -546,7 +546,7 @@ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 12, 12, 4, 128] // CHECK: [[CST0:%.+]] = arith.constant 0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 12, 12, 512] // CHECK: [[C128:%.+]] = arith.constant -128 // CHECK: [[C42:%.+]] = arith.constant 42 @@ -570,7 +570,7 @@ func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () { // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 4, 128] // CHECK: [[CST0:%.+]] = arith.constant 0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 10, 10, 512] // CHECK: [[C128:%.+]] = arith.constant -128 // CHECK: [[C42:%.+]] = arith.constant 42 diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -694,7 +694,7 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () { // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] // CHECK: [[CST0:%.+]] = arith.constant 0.0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>) // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: [[RES:%.+]] = arith.addf %arg1, %arg2 : f32 @@ -704,7 +704,7 @@ // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] // CHECK: [[CST0:%.+]] = arith.constant 0.0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>) // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: [[RES:%.+]] = arith.addf %arg1, %arg2 : f32 @@ -745,7 +745,7 @@ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 4] // CHECK: %[[CST0:.+]] = arith.constant 0.0 - // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST0]], %[[INIT]]) + // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT]] // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: %[[RES:.+]] = arith.addf %arg1, %arg2 : f32 @@ -767,7 +767,7 @@ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C1]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, %[[DYN]]] // CHECK: %[[CST1:.+]] = arith.constant 1.0 - // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST1]], %[[INIT]]) + // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT]] // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<5x?x4xf32>) outs(%[[FILL]] : tensor<5x?xf32>) // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32 @@ -789,7 +789,7 @@ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]] // CHECK: %[[CMIN:.+]] = arith.constant -3.40282347E+38 - // CHECK: %[[FILL:.+]] = linalg.fill(%[[CMIN]], %[[INIT]]) + // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]] // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32 @@ -811,7 +811,7 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () { // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] // CHECK: [[CST0:%.+]] = arith.constant 0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>) // CHECK: ^bb0(%arg1: i32, %arg2: i32) // CHECK: [[RES:%.+]] = arith.addi %arg1, %arg2 : i32 @@ -821,7 +821,7 @@ // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] // CHECK: [[CST0:%.+]] = arith.constant 0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>) // CHECK: ^bb0(%arg1: i32, %arg2: i32) // CHECK: [[RES:%.+]] = arith.addi %arg1, %arg2 : i32 @@ -861,7 +861,7 @@ func @reduce_bool(%arg0: tensor<5x4xi1>) -> () { // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] // CHECK: [[CST0:%.+]] = arith.constant true - // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]] // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<4xi1>) // CHECK: ^bb0(%arg1: i1, %arg2: i1) // CHECK: [[RES:%.+]] = arith.andi %arg1, %arg2 : i1 @@ -889,7 +889,7 @@ // CHECK: [[IDX1:%.+]] = arith.constant 1 : index // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1] // CHECK: [[CST:%.+]] = arith.constant 0.0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST]]{{.*}}outs([[INIT]] // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1] // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]][5, 0] [6, 1] [1, 1] %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>) @@ -901,7 +901,7 @@ // CHECK: [[IDX1:%.+]] = arith.constant 1 : index // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2] // CHECK: [[CST:%.+]] = arith.constant 0.0 - // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]]) + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST]]{{.*}}outs([[INIT]] // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1] // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]][0, 1] [5, 1] [1, 1] %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>) @@ -922,7 +922,7 @@ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX1_2]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [11, %[[DYN]]] // CHECK: %[[CST:.+]] = arith.constant 0.0 - // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]]) + // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]]{{.*}}outs(%[[INIT]] // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [5, %[[SIZE]]] [1, 1] // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1] %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>) @@ -943,7 +943,7 @@ // CHECK: %[[IDX1:.+]] = arith.constant 1 : index // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 3] // CHECK: %[[CST:.+]] = arith.constant 0.0 - // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]]) + // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]]{{.*}}outs(%[[INIT]] // CHECK: %[[DYN1:.+]] = tensor.dim %arg0, %[[AXIS]] // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [%[[DYN1]], 3] [1, 1] // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]] @@ -1330,10 +1330,10 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () { // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [2] // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32 - // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_MIN]], [[IDX_INIT]]) + // CHECK: [[IDX_FILL:%.+]] = linalg.fill ins([[IDX_MIN]]{{.*}}outs([[IDX_INIT]] // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [2] // CHECK: [[VAL_MIN:%.+]] = arith.constant -2147483648 - // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_MIN]], [[VAL_INIT]]) + // CHECK: [[VAL_FILL:%.+]] = linalg.fill ins([[VAL_MIN]]{{.*}}outs([[VAL_INIT]] // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<2xi32>, tensor<2xi32>) // CHECK: [[IDX:%.+]] = linalg.index 0 // CHECK: [[CAST:%.+]] = arith.index_cast [[IDX]] @@ -1345,10 +1345,10 @@ // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [3] // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32 - // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_MIN]], [[IDX_INIT]]) + // CHECK: [[IDX_FILL:%.+]] = linalg.fill ins([[IDX_MIN]]{{.*}}outs([[IDX_INIT]] // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [3] // CHECK: [[VAL_MIN:%.+]] = arith.constant -2147483648 - // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_MIN]], [[VAL_INIT]]) + // CHECK: [[VAL_FILL:%.+]] = linalg.fill ins([[VAL_MIN]]{{.*}}outs([[VAL_INIT]] // CHECK: linalg.generic {indexing_maps = [#map0, #map2, #map2], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>) // CHECK: [[IDX:%.+]] = linalg.index 1 // CHECK: [[CAST:%.+]] = arith.index_cast [[IDX]] @@ -1380,10 +1380,10 @@ // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]] // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [%[[DYN]]] // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32 - // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]]) + // CHECK: %[[IDX_FILL:.+]] = linalg.fill ins(%[[IDX_MIN]]{{.*}}outs(%[[IDX_INIT]] // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [%[[DYN]]] // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648 - // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]]) + // CHECK: %[[VAL_FILL:.+]] = linalg.fill ins(%[[VAL_MIN]]{{.*}}outs(%[[VAL_INIT]] // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor, tensor) // CHECK: %[[IDX:.+]] = linalg.index 0 // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] @@ -1403,10 +1403,10 @@ func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () { // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [3] // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32 - // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]]) + // CHECK: %[[IDX_FILL:.+]] = linalg.fill ins(%[[IDX_MIN]]{{.*}}outs(%[[IDX_INIT]] // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [3] // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648 - // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]]) + // CHECK: %[[VAL_FILL:.+]] = linalg.fill ins(%[[VAL_MIN]]{{.*}}outs(%[[VAL_INIT]] // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>) // CHECK: %[[IDX:.+]] = linalg.index 1 // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir @@ -148,9 +148,9 @@ // CHECK: %[[m1_casted:.*]] = memref.cast %[[m1]] : memref<10xf32> to memref<10xf32, #[[$MAP3]]> %t1 = linalg.init_tensor [10] : tensor<10xf32> - // CHECK: linalg.fill(%{{.*}}, %[[m1]]) + // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[m1]] // CHECK: %[[filled_tensor:.*]] = bufferization.to_tensor %[[m1_casted]] - %filled = linalg.fill(%cst, %t1) : f32, tensor<10xf32> -> tensor<10xf32> + %filled = linalg.fill ins(%cst : f32) outs(%t1 : tensor<10xf32>) -> tensor<10xf32> // The transfer_write is out-of-place because "dummy_op" may read. // CHECK: memref.copy %[[m1]], %[[alloc]] diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -167,10 +167,10 @@ func @bufferize_fill(%arg0: tensor) -> tensor { %c0 = arith.constant 0.0 : f32 // CHECK: %[[ALLOC:.*]] = memref.alloc - // CHECK: linalg.fill(%cst, %[[ALLOC]]) : f32, memref + // CHECK: linalg.fill ins(%cst : f32) outs(%[[ALLOC]] : memref) // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref // CHECK: return %[[TENSOR]] - %0 = linalg.fill(%c0, %arg0) : f32, tensor -> tensor + %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -275,7 +275,7 @@ %c0_i32 = arith.constant 0 : i32 %c0 = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill(%c0_i32, %arg0) : i32, tensor<7x7xi32> -> tensor<7x7xi32> + %0 = linalg.fill ins(%c0_i32 : i32) outs(%arg0 : tensor<7x7xi32>) -> tensor<7x7xi32> %1 = linalg.matmul ins(%arg1, %arg1: tensor<7x7xf32>, tensor<7x7xf32>) outs(%arg1: tensor<7x7xf32>) -> tensor<7x7xf32> %2 = linalg.generic #trait outs(%arg0 : tensor<7x7xi32>) { @@ -298,7 +298,7 @@ %c21 = arith.constant 21 : index %c42 = arith.constant 42 : index %0 = linalg.init_tensor [%c21, %c42] : tensor - %1 = linalg.fill(%arg1, %0) : f32, tensor -> tensor + %1 = linalg.fill ins(%arg1 : f32) outs(%0 : tensor) -> tensor %2 = tensor.dim %arg0, %c0 : tensor %3 = tensor.dim %arg0, %c1 : tensor %4 = tensor.insert_slice %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor into tensor @@ -306,7 +306,7 @@ } // CHECK-LABEL: func @propogate_casts // CHECK: %[[INIT:.+]] = linalg.init_tensor [21, 42] -// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %{{.+}} into %[[FILL]] // CHECK: %[[RESULT:.+]] = tensor.cast %[[INSERTED]] // CHECK: return %[[RESULT]] @@ -330,8 +330,8 @@ %zero = arith.constant 0.0 : f32 // CHECK: %[[INIT:.+]] = linalg.init_tensor [6, 4] : tensor<6x4xf32> %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32> - // CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<6x4xf32> -> tensor<6x4xf32> - %fill = linalg.fill(%zero, %init) : f32, tensor<1x2x3x4xf32> -> tensor<1x2x3x4xf32> + // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<6x4xf32>) -> tensor<6x4xf32> + %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]] : tensor<1x2x3x4xf32> into tensor<6x4xf32> // CHECK: return %[[FILL]] : tensor<6x4xf32> @@ -345,8 +345,8 @@ func @fold_fill_reshape_dynamic(%arg0 : tensor) -> tensor { %zero = arith.constant 0.0 : f32 // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] - %0 = linalg.fill(%zero, %arg0) : f32, tensor -> tensor - // CHECK: %[[RESULT:.+]] = linalg.fill(%{{.+}}, %[[RESHAPE]]) + %0 = linalg.fill ins(%zero : f32) outs(%arg0 : tensor) -> tensor + // CHECK: %[[RESULT:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[RESHAPE]] %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]] : tensor into tensor // CHECK: return %[[RESULT]] @@ -395,15 +395,15 @@ // CHECK: func @fold_self_copy func @fold_self_copy(%0 : memref<4x16xf32>) { // CHECK-NEXT: return - linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%0 : memref<4x16xf32>) + linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0 : memref<4x16xf32>) outs(%0 : memref<4x16xf32>) { ^bb0(%arg4: f32, %arg5: f32): linalg.yield %arg4 : f32 } - return + return } // ----- @@ -411,12 +411,12 @@ // CHECK-LABEL: func @fold_static_pad_fill // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32> -// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]] // CHECK: return %[[FILL]] func @fold_static_pad_fill() -> tensor<412x276xf32> { %f0 = arith.constant 0.0 : f32 %init = linalg.init_tensor [400, 273] : tensor<400x273xf32> - %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<400x273xf32>) -> tensor<400x273xf32> %pad = tensor.pad %fill low[4, 1] high[8, 2] { ^bb0(%arg1: index, %arg2: index): tensor.yield %f0 : f32 @@ -436,18 +436,18 @@ // CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[OF:.+]] = linalg.fill(%[[F0]], %[[SRC]]) : f32, tensor<8x?x16x32xf32> +// CHECK: %[[OF:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[SRC]] : tensor<8x?x16x32xf32>) // CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]] // CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32> // CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]] // CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]] // CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]] // CHECK: return %[[FILL]] func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor { %f0 = arith.constant 0.0 : f32 - %fill = linalg.fill(%f0, %init) : f32, tensor<8x?x16x32xf32> -> tensor<8x?x16x32xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x?x16x32xf32>) -> tensor<8x?x16x32xf32> %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] { ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): tensor.yield %f0 : f32 @@ -462,7 +462,7 @@ %f0 = arith.constant 0.0 : f32 %f1 = arith.constant 1.0 : f32 %init = linalg.init_tensor [400, 273] : tensor<400x273xf32> - %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<400x273xf32>) -> tensor<400x273xf32> // CHECK: tensor.pad %pad = tensor.pad %fill low[4, 1] high[8, 2] { ^bb0(%arg1: index, %arg2: index): @@ -635,7 +635,7 @@ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[INIT:.+]] = linalg.init_tensor [8, 384, 384] -// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]] // CHECK: %[[OFFSET1:.+]] = affine.apply #[[$MAP]]()[%[[LOW1]]] // CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor // CHECK: %[[D1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor @@ -649,7 +649,7 @@ tensor.yield %f0 : f32 } : tensor to tensor<8x128x128xf32> %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> return %0: tensor<8x384x384xf32> } @@ -670,7 +670,7 @@ tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> @@ -689,7 +689,7 @@ tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %1 = tensor.insert_slice %a into %0 [0, 0, 129] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> // Range overlap with %1 at dim#3 @@ -709,7 +709,7 @@ tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %1 = tensor.insert_slice %a into %0 [0, 128, 255] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> // Range overlap with %0 at dim#3 @@ -729,7 +729,7 @@ tensor.yield %f0 : f32 } : tensor<7x123x124xf32> to tensor<8x128x128xf32> %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> - %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32> + %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> // Overlap btween %0 and %1 is fine but not with %2 is fine. // CHECK-COUNT-3: tensor.insert_slice %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> @@ -752,7 +752,7 @@ } : tensor<7x123x124xf32> to tensor<8x128x128xf32> %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> // Different filling value than padding value. - %fill = linalg.fill(%f1, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32> + %fill = linalg.fill ins(%f1 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir --- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir +++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir @@ -68,7 +68,7 @@ // CHECK-FUSE: %[[CST:.*]] = arith.constant dense<0.000000e+00> // CHECK-FUSE: vector.transfer_write %[[CST]] %cst = arith.constant 0.0 : f32 - %0 = linalg.fill(%cst, %arg0) : f32, tensor<72x72xf32> -> tensor<72x72xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<72x72xf32>) -> tensor<72x72xf32> // Check the matmul is padded and vectorized despite the empty anchor op string. // CHECK-FUSE: vector.outerproduct @@ -81,7 +81,7 @@ // CHECK-DECOMP: func @conv( func @conv(%arg0: tensor<8x18x17x32xf32>, %arg1: tensor<3x3x32x64xf32>, %arg2: tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32> { %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill(%cst, %arg2) : f32, tensor<8x16x15x64xf32> -> tensor<8x16x15x64xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32> // Check the conv is padded by a rank-reducing vector transfer op pair. // CHECK-DECOMP: vector.transfer_read {{.*}}: tensor<1x1x?x8xf32>, vector<1x8x8xf32> diff --git a/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir b/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir @@ -18,9 +18,9 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} @@ -45,9 +45,9 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} @@ -71,11 +71,11 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} @@ -97,13 +97,13 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%3, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -123,11 +123,11 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} @@ -149,13 +149,13 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%3, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -176,9 +176,9 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} @@ -203,9 +203,9 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} @@ -230,11 +230,11 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} @@ -257,13 +257,13 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %4 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -284,11 +284,11 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} @@ -311,13 +311,13 @@ %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "false"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %4 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -340,9 +340,9 @@ // CHECK: {__inplace_operands_attr__ = ["false"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} @@ -367,11 +367,11 @@ // CHECK: {__inplace_operands_attr__ = ["false"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -394,9 +394,9 @@ // CHECK: {__inplace_operands_attr__ = ["false"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} @@ -421,11 +421,11 @@ // CHECK: {__inplace_operands_attr__ = ["false"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %4 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -450,9 +450,9 @@ // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -477,9 +477,9 @@ // CHECK: {__inplace_operands_attr__ = ["true"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -502,9 +502,9 @@ // CHECK: {__inplace_operands_attr__ = ["false"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} @@ -529,11 +529,11 @@ // CHECK: {__inplace_operands_attr__ = ["false"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%3, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -556,9 +556,9 @@ // CHECK: {__inplace_operands_attr__ = ["false"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} @@ -583,11 +583,11 @@ // CHECK: {__inplace_operands_attr__ = ["false"]} %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -612,9 +612,9 @@ // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> @@ -639,9 +639,9 @@ // CHECK: {__inplace_operands_attr__ = ["true"]} %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32> + %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32> // CHECK: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32> // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]} %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> return %5 : tensor<256x256xf32> diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir @@ -11,8 +11,8 @@ // CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32 %v0 = arith.constant 0.0 : f32 - // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref - %d = linalg.fill(%v0, %c) : f32, tensor -> tensor + // CHECK-NEXT: linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref) + %d = linalg.fill ins(%v0 : f32) outs(%c : tensor) -> tensor // CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref) %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>) @@ -41,12 +41,12 @@ %B = linalg.init_tensor [64] : tensor<64xf32> %C = linalg.init_tensor [] : tensor - // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32> - // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32> - // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref - %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32> - %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32> - %CC = linalg.fill(%v0, %C) : f32, tensor -> tensor + // CHECK-NEXT: linalg.fill ins(%[[C1]] : f32) outs(%[[A]] : memref<64xf32>) + // CHECK-NEXT: linalg.fill ins(%[[C2]] : f32) outs(%[[B]] : memref<64xf32>) + // CHECK-NEXT: linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref) + %AA = linalg.fill ins(%v1 : f32) outs(%A : tensor<64xf32>) -> tensor<64xf32> + %BB = linalg.fill ins(%v2 : f32) outs(%B : tensor<64xf32>) -> tensor<64xf32> + %CC = linalg.fill ins(%v0 : f32) outs(%C : tensor) -> tensor // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]]) %res = call @init_and_dot(%AA, %BB, %CC) : diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir @@ -16,7 +16,7 @@ // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"] - %1 = linalg.fill(%cst, %0) : f32, tensor -> tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true", "false", "none"] @@ -43,7 +43,7 @@ // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"] - %1 = linalg.fill(%cst, %0) : f32, tensor -> tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none"] diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -257,7 +257,7 @@ // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %0) : f32, tensor -> tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]} @@ -289,7 +289,7 @@ // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]} - %1 = linalg.fill(%cst, %0) : f32, tensor -> tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]} @@ -301,7 +301,7 @@ // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]} - %5 = linalg.fill(%cst, %4) : f32, tensor -> tensor + %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor) -> tensor %3 = vector.transfer_read %1[%idx2], %cst2 : tensor, vector<5xf32> @@ -501,7 +501,7 @@ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "false", "none", "none"]} %sA = tensor.extract_slice %A[0, 0][%idx, %idx][1, 1] : tensor to tensor %ssA = tensor.extract_slice %sA[0, 0][4, 4][1, 1] : tensor to tensor<4x4xf32> - %FA = linalg.fill(%f0, %ssA) : f32, tensor<4x4xf32> -> tensor<4x4xf32> + %FA = linalg.fill ins(%f0 : f32) outs(%ssA : tensor<4x4xf32>) -> tensor<4x4xf32> %rsA = tensor.insert_slice %FA into %sA[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor %rA = tensor.insert_slice %rsA into %A[0, 0][%idx, %idx][1, 1] : tensor into tensor @@ -524,7 +524,7 @@ %sB = tensor.extract_slice %B[0, 0][%idx, %idx][1, 1] : tensor to tensor %ssB = tensor.extract_slice %sB[0, 0][4, %idx][1, 1] : tensor to tensor<4x?xf32> %sssB = tensor.extract_slice %ssB[0, 0][4, 4][1, 1] : tensor<4x?xf32> to tensor<4x4xf32> - %FB = linalg.fill(%f0, %sssB) : f32, tensor<4x4xf32> -> tensor<4x4xf32> + %FB = linalg.fill ins(%f0 : f32) outs(%sssB : tensor<4x4xf32>) -> tensor<4x4xf32> %rssB = tensor.insert_slice %FB into %ssB[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor<4x?xf32> %rsB = tensor.insert_slice %rssB into %sB[0, 0][4, %idx][1, 1] : tensor<4x?xf32> into tensor %rB = tensor.insert_slice %rsB into %B[0, 0][%idx, %idx][1, 1] : tensor into tensor @@ -547,7 +547,7 @@ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]} %sC = tensor.extract_slice %C[0, 0][%idx, %idx][1, 1] : tensor to tensor %ssC = tensor.extract_slice %sC[0, 0][%sz1, 4][1, 1] : tensor to tensor - %FC = linalg.fill(%f0, %ssC) : f32, tensor -> tensor + %FC = linalg.fill ins(%f0 : f32) outs(%ssC : tensor) -> tensor %rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor into tensor %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor into tensor @@ -689,12 +689,12 @@ // cannot bufferize inplace. // CHECK: fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]} - %A = linalg.fill(%f1, %I) : f32, tensor<64xf32> -> tensor<64xf32> + %A = linalg.fill ins(%f1 : f32) outs(%I : tensor<64xf32>) -> tensor<64xf32> // 1. Bufferizes inplace: no alias to %A is yet possible. // CHECK: fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]} - %B = linalg.fill(%f2, %I) : f32, tensor<64xf32> -> tensor<64xf32> + %B = linalg.fill ins(%f2 : f32) outs(%I : tensor<64xf32>) -> tensor<64xf32> call @foo(%A) : (tensor<64xf32>) -> () call @foo(%B) : (tensor<64xf32>) -> () @@ -725,12 +725,12 @@ // bufferize inplace. // CHECK: fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]} - %A = linalg.fill(%f1, %I) : f32, tensor<64xf32> -> tensor<64xf32> + %A = linalg.fill ins(%f1 : f32) outs(%I : tensor<64xf32>) -> tensor<64xf32> // 4. Bufferizes inplace: no alias to %A is yet possible. // CHECK: fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]} - %B = linalg.fill(%f2, %I) : f32, tensor<64xf32> -> tensor<64xf32> + %B = linalg.fill ins(%f2 : f32) outs(%I : tensor<64xf32>) -> tensor<64xf32> // 3. Does not read or write, bufferizes inplace. // CHECK: scf.for @@ -750,12 +750,12 @@ // cannot bufferize inplace. // CHECK: fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]} - %A2 = linalg.fill(%f1, %I2) : f32, tensor<64xf32> -> tensor<64xf32> + %A2 = linalg.fill ins(%f1 : f32) outs(%I2 : tensor<64xf32>) -> tensor<64xf32> // 1. Bufferizes inplace: no alias to %A2 is yet possible. // CHECK: fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]} - %B2 = linalg.fill(%f2, %I2) : f32, tensor<64xf32> -> tensor<64xf32> + %B2 = linalg.fill ins(%f2 : f32) outs(%I2 : tensor<64xf32>) -> tensor<64xf32> call @bar(%A2) : (tensor<64xf32>) -> () call @bar(%B2) : (tensor<64xf32>) -> () @@ -800,8 +800,8 @@ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]} // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]} - %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32> - %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<256x256xf32>) -> tensor<256x256xf32> + %11 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<256x256xf32>) -> tensor<256x256xf32> // CHECK: tensor.extract_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true"]} @@ -838,7 +838,7 @@ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]} // CHECK: vector.transfer_write // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none", "none"] - %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<256x256xf32>) -> tensor<256x256xf32> %9 = vector.transfer_read %arg0[%c0, %c0], %cst_0 {in_bounds = [false, true]} : tensor<518x518xf32>, vector<256x256xf32> %10 = vector.transfer_write %9, %8[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, tensor<256x256xf32> @@ -846,7 +846,7 @@ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]} // CHECK: vector.transfer_write // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none", "none"] - %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32> + %11 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<256x256xf32>) -> tensor<256x256xf32> %12 = vector.transfer_read %arg1[%c0, %c0], %cst_0 {in_bounds = [false, true]} : tensor<518x518xf32>, vector<256x256xf32> %13 = vector.transfer_write %12, %11[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, tensor<256x256xf32> @@ -891,7 +891,7 @@ // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"] - %0 = linalg.fill(%cst, %arg2) : f32, tensor<62x90xf32> -> tensor<62x90xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<62x90xf32>) -> tensor<62x90xf32> // CHECK: tensor.extract_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true"] diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir @@ -23,8 +23,8 @@ // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] %a = linalg.init_tensor[%sz] : tensor - // CHECK: linalg.fill({{.*}}, %[[EXTRACT_SLICE_ALLOC]]) : f32, memref - %f = linalg.fill(%f0, %a) : f32, tensor -> tensor + // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[EXTRACT_SLICE_ALLOC]] : memref) + %f = linalg.fill ins(%f0 : f32) outs(%a : tensor) -> tensor // CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref to memref // CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref to memref @@ -54,8 +54,8 @@ // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] %a = linalg.init_tensor[%sz] : tensor - // CHECK: linalg.fill({{.*}}, %[[T_SUBVIEW]]) : f32, memref -> tensor + // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW]] : memref) -> tensor // Self-copy canonicalizes away later. %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor into tensor @@ -81,8 +81,8 @@ %iv_i32 = arith.index_cast %iv : index to i32 %f = arith.sitofp %iv_i32 : i32 to f32 - // CHECK: linalg.fill(%{{.*}}, %[[subview]]) - %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32> + // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[subview]] + %filled = linalg.fill ins(%f : f32) outs(%blank : tensor<5xf32>) -> tensor<5xf32> // CHECK-NOT: memref.copy %inserted = tensor.insert_slice %filled into %bb[%iv][5][1] : tensor<5xf32> into tensor @@ -111,8 +111,8 @@ %iv_i32 = arith.index_cast %iv : index to i32 %f = arith.sitofp %iv_i32 : i32 to f32 - // CHECK: linalg.fill(%{{.*}}, %[[subview]]) - %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32> + // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[subview]] + %filled = linalg.fill ins(%f : f32) outs(%blank : tensor<5xf32>) -> tensor<5xf32> // CHECK-NOT: memref.copy %inserted = tensor.insert_slice %filled into %bb[%idx][5][1] : tensor<5xf32> into tensor diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -173,7 +173,7 @@ func @mini_test_case1() -> tensor<10x20xf32> { %f0 = arith.constant 0.0 : f32 %t = linalg.init_tensor [10, 20] : tensor<10x20xf32> - %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32> + %r = linalg.fill ins(%f0 : f32) outs(%t : tensor<10x20xf32>) -> tensor<10x20xf32> // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} return %r : tensor<10x20xf32> } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -40,8 +40,8 @@ /// Inplaceable, no alloc // CHECK-NOT: alloc - // CHECK: linalg.fill(%[[F0]], %[[A]]) : f32, memref - %r = linalg.fill(%f0, %A) : f32, tensor -> tensor + // CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[A]] : memref) + %r = linalg.fill ins(%f0 : f32) outs(%A : tensor) -> tensor // CHECK: return // CHECK-NOT: tensor @@ -78,8 +78,8 @@ // CHECK: %[[D0:.*]] = memref.dim %[[A]], {{.*}} : memref // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[D0]]) {alignment = 128 : i64} : memref - // CHECK: linalg.fill(%[[F0]], %[[ALLOC]]) : f32, memref - %r = linalg.fill(%f0, %A) : f32, tensor -> tensor + // CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[ALLOC]] : memref) + %r = linalg.fill ins(%f0 : f32) outs(%A : tensor) -> tensor // CHECK: dealloc %[[ALLOC]] : memref // CHECK: return %[[ALLOC]] : memref @@ -101,8 +101,8 @@ /// Cross-op multiple uses of %A, the first op which has interfering reads must alloc. // CHECK: %[[ALLOC:.*]] = memref.alloc - // CHECK: linalg.fill({{.*}}, %[[ALLOC]] - %f = linalg.fill(%f0, %A) : f32, tensor -> tensor + // CHECK: linalg.fill ins({{.*}}{{.*}}outs(%[[ALLOC]] + %f = linalg.fill ins(%f0 : f32) outs(%A : tensor) -> tensor /// The second op has no interfering reads and can reuse. // CHECK-NOT: alloc @@ -240,8 +240,8 @@ %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor /// Overwrite A inplace. - // CHECK: linalg.fill({{.*}}, %[[A]] - %r1 = linalg.fill(%f0, %r0) : f32, tensor -> tensor + // CHECK: linalg.fill ins({{.*}}{{.*}}outs(%[[A]] + %r1 = linalg.fill ins(%f0 : f32) outs(%r0 : tensor) -> tensor // CHECK: return // CHECK-NOT: tensor @@ -262,8 +262,8 @@ { %f0 = arith.constant 0.0 : f32 - // CHECK: linalg.fill({{.*}}, %[[A]] - %r0 = linalg.fill(%f0, %A) : f32, tensor -> tensor + // CHECK: linalg.fill ins({{.*}}{{.*}}outs(%[[A]] + %r0 = linalg.fill ins(%f0 : f32) outs(%A : tensor) -> tensor // CHECK-NOT: alloc // CHECK: %[[SV_A:.*]] = memref.subview %[[A]] @@ -581,8 +581,8 @@ // CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32 %v0 = arith.constant 0.0 : f32 - // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref - %d = linalg.fill(%v0, %c) : f32, tensor -> tensor + // CHECK-NEXT: linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref) + %d = linalg.fill ins(%v0 : f32) outs(%c : tensor) -> tensor // CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref) %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>) @@ -611,12 +611,12 @@ %B = linalg.init_tensor [64] : tensor<64xf32> %C = linalg.init_tensor [] : tensor - // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32> - // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32> - // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref - %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32> - %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32> - %CC = linalg.fill(%v0, %C) : f32, tensor -> tensor + // CHECK-NEXT: linalg.fill ins(%[[C1]] : f32) outs(%[[A]] : memref<64xf32>) + // CHECK-NEXT: linalg.fill ins(%[[C2]] : f32) outs(%[[B]] : memref<64xf32>) + // CHECK-NEXT: linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref) + %AA = linalg.fill ins(%v1 : f32) outs(%A : tensor<64xf32>) -> tensor<64xf32> + %BB = linalg.fill ins(%v2 : f32) outs(%B : tensor<64xf32>) -> tensor<64xf32> + %CC = linalg.fill ins(%v0 : f32) outs(%C : tensor) -> tensor // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]]) %res = call @init_and_dot(%AA, %BB, %CC) : @@ -730,8 +730,8 @@ tensor<128x192xf32> to tensor<8x16xf32> // linalg.fill is inplace. - // CHECK: linalg.fill(%{{.*}}, %[[ALLOC]]) : f32, memref<8x16xf32> - %5 = linalg.fill(%cst, %4) : f32, tensor<8x16xf32> -> tensor<8x16xf32> + // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[ALLOC]] : memref<8x16xf32>) + %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK: scf.for %[[K:.*]] = %6 = scf.for %arg7 = %c0 to %c256 step %c32 iter_args(%arg8 = %5) -> (tensor<8x16xf32>) { @@ -800,7 +800,7 @@ %sA = tensor.extract_slice %A[0, 0][%idx, %idx][1, 1] : tensor to tensor %ssA = tensor.extract_slice %sA[0, 0][4, 4][1, 1] : tensor to tensor<4x4xf32> - %FA = linalg.fill(%f0, %ssA) : f32, tensor<4x4xf32> -> tensor<4x4xf32> + %FA = linalg.fill ins(%f0 : f32) outs(%ssA : tensor<4x4xf32>) -> tensor<4x4xf32> %rsA = tensor.insert_slice %FA into %sA[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor %rA = tensor.insert_slice %rsA into %A[0, 0][%idx, %idx][1, 1] : tensor into tensor diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -156,7 +156,7 @@ %0 = linalg.generic #trait ins(%arg0 : tensor<1x5xf32>) outs(%shape : tensor<5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): + ^bb0(%arg2: f32, %arg3: f32): linalg.yield %arg2 : f32 } -> tensor<5xf32> return %0 : tensor<5xf32> @@ -250,7 +250,7 @@ %2 = linalg.generic {i64, indexing_maps = [#map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<5xf32>) outs(%1 : tensor<1x2x5xf32>) { - ^bb0(%arg1: f32, %arg2: f32): + ^bb0(%arg1: f32, %arg2: f32): linalg.yield %arg1 : f32 } -> tensor<1x2x5xf32> %3 = tensor.collapse_shape %2 [[0, 1], [2]] @@ -266,7 +266,7 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> { %cst = arith.constant 0.0 : f32 %init = linalg.init_tensor [1] : tensor<1xf32> - %fill = linalg.fill(%cst, %init) : f32, tensor<1xf32> -> tensor<1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1xf32>) -> tensor<1xf32> %add = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -287,7 +287,7 @@ // CHECK: %[[INPUT_RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32> // CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor -> tensor +// CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor) -> tensor // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["reduction"] @@ -331,14 +331,14 @@ %c3 = arith.constant 3 : index %0 = tensor.dim %arg0, %c3 : tensor<1x?x1x?xf32> %1 = linalg.init_tensor [1, %0] : tensor<1x?xf32> - %2 = linalg.fill(%cst, %1) : f32, tensor<1x?xf32> -> tensor<1x?xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x?xf32>) -> tensor<1x?xf32> %3 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0 : tensor<1x?x1x?xf32>) outs(%2 : tensor<1x?xf32>) { - ^bb0(%arg1: f32, %arg2: f32): + ^bb0(%arg1: f32, %arg2: f32): %4 = arith.addf %arg1, %arg2 : f32 linalg.yield %4 : f32 } -> tensor<1x?xf32> @@ -350,7 +350,7 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32> // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] @@ -365,14 +365,14 @@ %cst = arith.constant 1.000000e+00 : f32 %c3 = arith.constant 3 : index %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32> - %2 = linalg.fill(%cst, %1) : f32, tensor<1x1xf32> -> tensor<1x1xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32> %3 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0 : tensor<1x?x1x1xf32>) outs(%2 : tensor<1x1xf32>) { - ^bb0(%arg1: f32, %arg2: f32): + ^bb0(%arg1: f32, %arg2: f32): %4 = arith.addf %arg1, %arg2 : f32 linalg.yield %4 : f32 } -> tensor<1x1xf32> @@ -383,7 +383,7 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32> // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3] // CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32> -// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["parallel"] @@ -399,14 +399,14 @@ %c2 = arith.constant 2 : index %0 = tensor.dim %arg0, %c2 : tensor %1 = linalg.init_tensor [%0, 1] : tensor - %2 = linalg.fill(%cst, %1) : f32, tensor -> tensor + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor %3 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0 : tensor) outs(%2 : tensor) { - ^bb0(%arg1: f32, %arg2: f32): + ^bb0(%arg1: f32, %arg2: f32): %4 = arith.addf %arg1, %arg2 : f32 linalg.yield %4 : f32 } -> tensor @@ -418,7 +418,7 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] // CHECK: %[[RESULT:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] @@ -608,7 +608,7 @@ linalg.generic #trait ins(%arg0 : memref<1x5xf32>) outs(%shape : memref<5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): + ^bb0(%arg2: f32, %arg3: f32): linalg.yield %arg2 : f32 } return %shape : memref<5xf32> @@ -702,7 +702,7 @@ linalg.generic {i64, indexing_maps = [#map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : memref<5xf32>) outs(%1 : memref<1x2x5xf32>) { - ^bb0(%arg1: f32, %arg2: f32): + ^bb0(%arg1: f32, %arg2: f32): linalg.yield %arg1 : f32 } %3 = memref.collapse_shape %1 [[0, 1], [2]] @@ -792,7 +792,7 @@ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]} // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref, f32) // CHECK-SAME: outs(%[[OUT]] : memref) { -// CHECK: ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): +// CHECK: ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): // CHECK: linalg.yield %[[ARG]] : f32 // CHECK: } // CHECK: return %[[ARG2]] : memref diff --git a/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir --- a/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir +++ b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir @@ -29,7 +29,7 @@ %c0 = arith.constant 0: index %f0 = arith.constant 0.0: f32 %alloc = memref.alloc() : memref<32 x f32> - linalg.fill(%f0, %alloc) : f32, memref<32 x f32> + linalg.fill ins(%f0 : f32) outs(%alloc : memref<32 x f32>) %subview = memref.subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32> memref.copy %in, %subview : memref to memref<16 x f32> %0 = vector.transfer_read %alloc[%c0], %f0 {in_bounds = [true]} : memref<32 x f32>, vector<32 x f32> @@ -69,7 +69,7 @@ %alloc = memref.alloc() : memref<128 x i8> %view = memref.view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32> %subview = memref.subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32> - linalg.fill(%f0, %view) : f32, memref<32 x f32> + linalg.fill ins(%f0 : f32) outs(%view : memref<32 x f32>) memref.copy %in, %subview : memref to memref<16 x f32> %0 = vector.transfer_read %view[%c0], %f0 {in_bounds = [true]} : memref<32 x f32>, vector<32 x f32> memref.dealloc %alloc : memref<128 x i8> @@ -129,7 +129,7 @@ %f0 = arith.constant 0.0: f32 %f1 = arith.constant 1.0: f32 %alloc = memref.alloc() : memref<32 x f32> - linalg.fill(%f0, %alloc) : f32, memref<32 x f32> + linalg.fill ins(%f0 : f32) outs(%alloc : memref<32 x f32>) %subview = memref.subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32> memref.copy %in, %subview : memref to memref<16 x f32> "some_interleaved_use"(%subview) : (memref<16 x f32>) -> () diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -934,7 +934,7 @@ linalg.yield %arg2 : f32 } -> tensor %6 = linalg.init_tensor [%arg1] : tensor - %7 = linalg.fill(%cst, %6) : f32, tensor -> tensor + %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor) -> tensor %8 = linalg.generic { indexing_maps = [#map2, #map3], iterator_types = ["parallel", "reduction"] diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -4,7 +4,7 @@ func @basic_fusion(%arg0: memref, %arg1: memref, %arg2: memref) { %cst = arith.constant 0.000000e+00 : f32 - linalg.fill(%cst, %arg2) : f32, memref + linalg.fill ins(%cst : f32) outs(%arg2 : memref) linalg.matmul {__internal_linalg_transform__ = "basic_fusion"} ins(%arg0, %arg1 : memref, memref) outs(%arg2 : memref) @@ -28,8 +28,9 @@ // CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index // CHECK-DAG: %[[CST:.+]] = arith.constant 0.0{{.*}} : f32 -// CHECK-DAG: linalg.fill(%[[CST]], %[[ARG2]]) +// CHECK-DAG: linalg.fill // CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" +// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[ARG2]] // CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]] // CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = @@ -51,8 +52,9 @@ // CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N_2]], %[[N]]] // CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] // CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]] -// CHECK: linalg.fill(%[[CST]], %[[SV3_2]]) +// CHECK: linalg.fill // CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" +// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[SV3_2]] // CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { // CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] // CHECK: %[[SV4:.+]] = memref.subview %[[SV1]][0, %[[IV2]]] @@ -250,7 +252,7 @@ %c64 = arith.constant 64 : index %c16 = arith.constant 16 : index %cst = arith.constant 0.000000e+00 : f32 - linalg.fill(%cst, %arg2) : f32, memref + linalg.fill ins(%cst : f32) outs(%arg2 : memref) %0 = memref.dim %arg0, %c0 : memref %1 = memref.dim %arg1, %c1 : memref %2 = memref.dim %arg0, %c1 : memref @@ -285,7 +287,7 @@ func @basic_conv_fusion(%arg0: memref, %arg1: memref, %arg2: memref) { %cst = arith.constant 0.000000e+00 : f32 - linalg.fill(%cst, %arg2) : f32, memref + linalg.fill ins(%cst : f32) outs(%arg2 : memref) linalg.conv_2d {__internal_linalg_transform__ = "basic_fusion"} ins(%arg1, %arg0 : memref, memref) outs(%arg2 : memref) return diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir @@ -9,7 +9,7 @@ %d0 = memref.dim %arg0, %c0 : memref %d1 = memref.dim %arg1, %c1 : memref %0 = memref.alloc(%d0, %d1) : memref - linalg.fill(%cst, %0) : f32, memref + linalg.fill ins(%cst : f32) outs(%0 : memref) linalg.matmul ins(%arg0, %arg1 : memref, memref) outs(%0 : memref) linalg.generic @@ -43,7 +43,7 @@ // CHECK-DAG: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]] // CHECK: %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] // CHECK: %[[SV_TEMP_3:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] -// CHECK: linalg.fill(%{{.+}}, %[[SV_TEMP_3]]) +// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_TEMP_3]] // CHECK: linalg.matmul // CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]] // CHECK-SAME: : memref, memref) @@ -70,13 +70,13 @@ %n3 = memref.dim %arg3, %c1 : memref %0 = memref.alloc(%m, %n1) : memref %1 = memref.alloc(%m, %n2) : memref - linalg.fill(%cst, %0) : f32, memref + linalg.fill ins(%cst : f32) outs(%0 : memref) linalg.matmul ins(%arg0, %arg1 : memref, memref) outs(%0 : memref) - linalg.fill(%cst, %1) : f32, memref + linalg.fill ins(%cst : f32) outs(%1 : memref) linalg.matmul ins(%0, %arg2 : memref, memref) outs(%1 : memref) - linalg.fill(%cst, %arg4) : f32, memref + linalg.fill ins(%cst : f32) outs(%arg4 : memref) linalg.matmul ins(%1, %arg3 : memref, memref) outs(%arg4 : memref) return @@ -126,15 +126,15 @@ // CHECK-SAME: [%[[TILE_M_5]], %[[N0]]] // CHECK: %[[SV_ALLOC4:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_5]], %[[N1]]] -// CHECK: linalg.fill(%{{.+}}, %[[SV_ALLOC1]]) +// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC1]] // CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]] // CHECK-SAME: : memref, memref) // CHECK-SAME: outs(%[[SV_ALLOC4]] : memref) -// CHECK: linalg.fill(%{{.+}}, %[[SV_ALLOC2]]) +// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC2]] // CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]] // CHECK-SAME: : memref, memref) // CHECK-SAME: outs(%[[SV_ALLOC2]] : memref) -// CHECK: linalg.fill(%{{.+}}, %[[SV_ARG4_2]]) +// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ARG4_2]] // CHECK: linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]] // CHECK-SAME: : memref, memref) // CHECK-SAME: outs(%[[SV_ARG4]] : memref) diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -142,7 +142,7 @@ func @matmul_out_fusion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %c0 = arith.constant 0.0 : f32 - %0 = linalg.fill(%c0, %arg0) : f32, tensor -> tensor + %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"} ins(%arg1, %arg2 : tensor, tensor) outs(%0 : tensor) -> tensor @@ -159,7 +159,9 @@ // CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor) { // CHECK: scf.for %[[J:.*]] // CHECK: %[[ST:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[ST_FILL:.*]] = linalg.fill(%[[C0]], %[[ST]]) {__internal_linalg_transform__ = "after_out_fusion_producer"} : f32, tensor -> tensor +// CHECK: %[[ST_FILL:.*]] = linalg.fill +// CHECK-SAME: {__internal_linalg_transform__ = "after_out_fusion_producer"} +// CHECK-SAME: ins(%[[C0]] : f32) outs(%[[ST]] : tensor) -> tensor // CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor) { // CHECK-NOT: fill // CHECK: %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0] diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -685,7 +685,7 @@ %c3 = arith.constant 3 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - linalg.fill(%cst, %arg0) : f32, memref + linalg.fill ins(%cst : f32) outs(%arg0 : memref) %2 = memref.dim %arg1, %c0 : memref %3 = memref.dim %arg1, %c1 : memref %4 = memref.dim %arg2, %c0 : memref diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -156,7 +156,7 @@ // ----- func @generalize_fill(%output: memref, %value : f32) { - linalg.fill(%value, %output) : f32, memref + linalg.fill ins(%value : f32) outs(%output : memref) return } diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -222,7 +222,7 @@ // ----- func @generalize_fill_0d(%value: f64, %O: tensor) -> tensor { - %0 = linalg.fill_tensor ins(%value: f64) outs(%O : tensor) -> tensor + %0 = linalg.fill ins(%value: f64) outs(%O : tensor) -> tensor return %0: tensor } @@ -236,7 +236,7 @@ // ----- func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) { - linalg.fill_tensor ins(%value: f64) outs(%O : memref<16x32xf32>) + linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>) return } diff --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir --- a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir +++ b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir @@ -4,7 +4,7 @@ // CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> { // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32> -// CHECK: %[[FILL:.*]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x32x32x1xf32> -> tensor<1x32x32x1xf32> +// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x32x32x1xf32>) -> tensor<1x32x32x1xf32> // CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32> // CHECK: return %[[PADDED]] : tensor<1x32x32x1xf32> func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> { @@ -29,7 +29,7 @@ // CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32> // CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index // CHECK: %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32> -// CHECK: %[[FILL:.*]] = linalg.fill(%[[CST]], %[[INIT]]) : f32, tensor<4x?x?x?xf32> -> tensor<4x?x?x?xf32> +// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32> // CHECK: %[[DIM1_1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32> // CHECK: %[[DIM3_1:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32> // CHECK: %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1_1]], 2, %[[DIM3_1]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32> diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir --- a/mlir/test/Dialect/Linalg/hoist-padding.mlir +++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir @@ -377,7 +377,7 @@ ^bb0(%arg5: index, %arg6: index): tensor.yield %cst : f32 } : tensor to tensor<5x24xf32> - %5 = linalg.fill(%cst, %4) : f32, tensor<5x24xf32> -> tensor<5x24xf32> + %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<5x24xf32>) -> tensor<5x24xf32> %6 = tensor.extract_slice %5[0, 0] [%1, 24] [1, 1] : tensor<5x24xf32> to tensor // Check the first input operand is hoisted by one loop nest. diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -369,16 +369,7 @@ { %0 = linalg.init_tensor [%arg0, %arg1] : tensor // expected-error @+1 {{expected the number of results (0) to be equal to the number of output tensors (1)}} - linalg.fill(%arg2, %0) : f32, tensor -} - -// ----- - -func @illegal_fill_memref_with_return(%arg0 : memref, %arg1 : f32) -> tensor -{ - // expected-error @+1 {{op expected the number of results (1) to be equal to the number of output tensors (0)}} - %0 = linalg.fill(%arg1, %arg0) : f32, memref -> tensor - return %0 : tensor + linalg.fill ins(%arg2 : f32) outs(%0 : tensor) } // ----- @@ -387,7 +378,7 @@ (%arg0 : memref, %arg1 : f32) -> tensor { // expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}} - %0 = linalg.fill(%arg1, %arg0) : f32, memref -> tensor + %0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) -> tensor return %0 : tensor } @@ -396,8 +387,8 @@ func @illegal_fill_tensor_with_memref_return (%arg0 : tensor, %arg1 : f32) -> memref { - // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'memref'}} - %0 = linalg.fill(%arg1, %arg0) : f32, tensor -> memref + // expected-error @+1 {{result #0 must be ranked tensor of any type values, but got 'memref'}} + %0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : tensor) -> memref return %0 : memref } diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/library-calls.mlir --- a/mlir/test/Dialect/Linalg/library-calls.mlir +++ b/mlir/test/Dialect/Linalg/library-calls.mlir @@ -14,7 +14,7 @@ %C = memref.alloc(%x, %y) : memref // CHECK: call @linalg_fill_f32_viewsxsxf32({{.*}}) : (f32, memref) - linalg.fill(%f0, %C) : f32, memref + linalg.fill ins(%f0 : f32) outs(%C : memref) // CHECK: call @linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32({{.*}}) : (memref, memref, memref) -> () linalg.matmul ins(%A, %B: memref, memref) diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -166,7 +166,7 @@ // CHECKPARALLEL: store %[[res]], %{{.*}}[] : memref func @fill_view(%arg0: memref, %arg1: f32) { - linalg.fill(%arg1, %arg0) : f32, memref + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) return } // CHECK-LABEL: func @fill_view( @@ -180,7 +180,7 @@ // CHECKPARALLEL: store %{{.*}}, %{{.*}}[%{{.*}}] : memref func @fill_view0(%arg0: memref, %arg1: f32) { - linalg.fill(%arg1, %arg0) : f32, memref + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) return } // CHECK-LABEL: func @fill_view0(%{{.*}}: memref, %{{.*}}: f32) { @@ -190,7 +190,7 @@ // CHECKPARALLEL: store %{{.*}}, %{{.*}}[] : memref func @fill_view3(%arg0: memref, %arg1: f32) { - linalg.fill(%arg1, %arg0) : f32, memref + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) return } // CHECK-LABEL: func @fill_view3( diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -4,7 +4,7 @@ func @depthwise_conv_2d_nhwc_hwcm_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> { %zero = arith.constant 0.000000e+00 : f32 %init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32> - %fill = linalg.fill(%zero, %init) : f32, tensor<2x3x4x2x3xf32> -> tensor<2x3x4x2x3xf32> + %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32> // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm // CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) @@ -70,7 +70,7 @@ func @depthwise_conv_2d_nhwc_hwcm_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> { %zero = arith.constant 0.000000e+00 : f32 %init = linalg.init_tensor [2, 6, 7, 2, 3] : tensor<2x6x7x2x3xf32> - %fill = linalg.fill(%zero, %init) : f32, tensor<2x6x7x2x3xf32> -> tensor<2x6x7x2x3xf32> + %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32> // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm // CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>) @@ -236,7 +236,7 @@ %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x1xf32> -> tensor<1x2x2x1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> %res = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>) outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> @@ -270,7 +270,7 @@ %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x1xf32> -> tensor<1x2x2x1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>) outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> @@ -289,7 +289,7 @@ %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> %init = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32> %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill(%cst, %init) : f32, tensor<1x1x2x2xf32> -> tensor<1x1x2x2xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> %res = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %fake: tensor<1x1x4x4xf32>, tensor<3x3xf32>) outs(%fill: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> @@ -323,7 +323,7 @@ %fake = linalg.init_tensor [3, 3] : tensor<3x3xi8> %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8> %cst = arith.constant 0 : i8 - %fill = linalg.fill(%cst, %init) : i8, tensor<1x2x2x1xi8> -> tensor<1x2x2x1xi8> + %fill = linalg.fill ins(%cst : i8) outs(%init : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %fake: tensor<1x4x4x1xi8>, tensor<3x3xi8>) outs(%fill: tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> @@ -357,7 +357,7 @@ %fake = linalg.init_tensor [3, 3] : tensor<3x3xi16> %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi16> %cst = arith.constant 0 : i16 - %fill = linalg.fill(%cst, %init) : i16, tensor<1x2x2x1xi16> -> tensor<1x2x2x1xi16> + %fill = linalg.fill ins(%cst : i16) outs(%init : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %fake: tensor<1x4x4x1xi16>, tensor<3x3xi16>) outs(%fill: tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16> @@ -391,7 +391,7 @@ %fake = linalg.init_tensor [3, 3] : tensor<3x3xi32> %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi32> %cst = arith.constant 0 : i32 - %fill = linalg.fill(%cst, %init) : i32, tensor<1x2x2x1xi32> -> tensor<1x2x2x1xi32> + %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>) outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> @@ -426,7 +426,7 @@ %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x1xf32> -> tensor<1x2x2x1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> %res = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>) outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> @@ -460,7 +460,7 @@ %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32> %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> %res = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>) outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> @@ -494,7 +494,7 @@ %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32> %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> %res = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>) outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> @@ -528,7 +528,7 @@ %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32> %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32> %cst = arith.constant 0.000000e+00 : f32 - %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> %res = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>) outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32> diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir --- a/mlir/test/Dialect/Linalg/pad.mlir +++ b/mlir/test/Dialect/Linalg/pad.mlir @@ -173,11 +173,11 @@ // Check both fill operations are padded by the same pad tensor operation. // FILL: %[[T0:.*]] = tensor.pad - // FILL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) - // FILL: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) + // FILL: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]] + // FILL: %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]] // FILL: = tensor.extract_slice %[[T2]] - %1 = linalg.fill(%cst, %0) : f32, tensor -> tensor - %2 = linalg.fill(%cst, %1) : f32, tensor -> tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor return %2 : tensor } @@ -198,15 +198,15 @@ // MATMUL-SAME: [0, 0] // MATMUL-SAME: [%[[SIZE]], %[[SIZE]]] // MATMUL: %[[T1:.*]] = tensor.pad %[[T0]] - // MATMUL: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]] - // MATMUL: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]] + // MATMUL: %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]] + // MATMUL: %[[T3:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T2]] %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0] { - ^bb0(%arg3: index, %arg4: index): + ^bb0(%arg3: index, %arg4: index): tensor.yield %cst : f32 } : tensor to tensor<64x64xf32> - %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32> - %3 = linalg.fill(%cst, %2) : f32, tensor<64x64xf32> -> tensor<64x64xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64x64xf32>) -> tensor<64x64xf32> %4 = tensor.extract_slice %3[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor // Check there are no additional pad tensor operations. @@ -234,10 +234,10 @@ %size = affine.min #map0()[%iv0] %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0] { - ^bb0(%arg3: index, %arg4: index): + ^bb0(%arg3: index, %arg4: index): tensor.yield %cst : f32 } : tensor to tensor<64x64xf32> - %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32> %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor // Different padding values prevent composing the paddings (42.0 vs. 0.0). @@ -259,10 +259,10 @@ %size = affine.min #map0()[%iv0] %0 = tensor.extract_slice %arg0[0, 0] [%iv0, %iv0] [1, 1] : tensor<64x64xf32> to tensor %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0] { - ^bb0(%arg3: index, %arg4: index): + ^bb0(%arg3: index, %arg4: index): tensor.yield %cst : f32 } : tensor to tensor<64x64xf32> - %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32> %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor // Different dynamic sizes prevent composing the paddings (%iv0 vs %size). @@ -284,10 +284,10 @@ %size = affine.min #map0()[%iv0] %0 = tensor.extract_slice %arg0[0, 0, 0] [%size, %size, 1] [1, 1, 1] : tensor<64x64x1xf32> to tensor %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0] { - ^bb0(%arg3: index, %arg4: index): + ^bb0(%arg3: index, %arg4: index): tensor.yield %cst : f32 } : tensor to tensor<64x64xf32> - %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32> %3 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor // Different dynamic ranks prevent composing the paddings ([%size, %size, 1] vs [%size, %size]). @@ -309,10 +309,10 @@ %size = affine.min #map0()[%iv0] %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<62x62xf32> to tensor %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0] { - ^bb0(%arg3: index, %arg4: index): + ^bb0(%arg3: index, %arg4: index): tensor.yield %cst : f32 } : tensor to tensor<62x62xf32> - %2 = linalg.fill(%cst, %1) : f32, tensor<62x62xf32> -> tensor<62x62xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<62x62xf32>) -> tensor<62x62xf32> %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<62x62xf32> to tensor // Different static sizes prevent composing the paddings (62 vs 64 derived from #map0). @@ -340,8 +340,8 @@ %1 = tensor.extract_slice %arg1[0, 0] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> // Check only the fill output operand is padded. - // FILL: %[[T6:.*]] = linalg.fill(%[[ARG0]], %[[T1]] - %2 = linalg.fill(%arg0, %1) : f32, tensor<4x?xf32> -> tensor<4x?xf32> + // FILL: %[[T6:.*]] = linalg.fill ins(%[[ARG0]]{{.*}}outs(%[[T1]] + %2 = linalg.fill ins(%arg0 : f32) outs(%1 : tensor<4x?xf32>) -> tensor<4x?xf32> %3 = tensor.insert_slice %2 into %arg1[0, 0] [4, %0] [1, 1] : tensor<4x?xf32> into tensor<24x12xf32> return %3 : tensor<24x12xf32> } @@ -466,9 +466,9 @@ // Check the fill is padded despite the rank-reducing slice operation. // FILL: %[[T0:.*]] = tensor.pad - // FILL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // FILL: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]] // FILL-SAME: tensor<1x64x64xf32> // FILL: = tensor.extract_slice %[[T1]] - %1 = linalg.fill(%cst, %0) : f32, tensor<1x?x?xf32> -> tensor<1x?x?xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> return %1 : tensor<1x?x?xf32> } diff --git a/mlir/test/Dialect/Linalg/pad_fusion.mlir b/mlir/test/Dialect/Linalg/pad_fusion.mlir --- a/mlir/test/Dialect/Linalg/pad_fusion.mlir +++ b/mlir/test/Dialect/Linalg/pad_fusion.mlir @@ -38,7 +38,7 @@ // CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] // CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]] -// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG5]], %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ARG5]]{{.*}}outs(%[[INIT]] // CHECK-DAG: %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]] // CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]] @@ -82,7 +82,7 @@ // CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] // CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]] -// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG3]], %[[INIT]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ARG3]]{{.*}}outs(%[[INIT]] // CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]] // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]] // CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1] diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir --- a/mlir/test/Dialect/Linalg/promotion_options.mlir +++ b/mlir/test/Dialect/Linalg/promotion_options.mlir @@ -23,9 +23,9 @@ // CHECK: %[[T19:.+]] = memref.subview %[[T18]] // CHECK: %[[T20:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref // CHECK: %[[T21:.+]] = memref.subview %[[T20]] -// CHECK: linalg.fill(%[[C42]], %[[T19]]) +// CHECK: linalg.fill ins(%[[C42]]{{.*}}outs(%[[T19]] // CHECK: memref.copy %[[T7]], %[[T19]] -// CHECK: linalg.fill(%[[C42]], %[[T21]]) +// CHECK: linalg.fill ins(%[[C42]]{{.*}}outs(%[[T21]] // CHECK: memref.copy %[[T17]], %[[T21]] // CHECK: linalg.matmul ins(%[[T19]], %[[T12]]{{.*}} outs(%[[T21]] // CHECK-NOT: linalg.fill diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -63,12 +63,12 @@ // ----- func @fill_view(%arg0: memref, %arg1: f32) { - linalg.fill(%arg1, %arg0) : f32, memref + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) return } // CHECK-LABEL: func @fill_view( // CHECK: %{{.*}}: memref, %{{.*}}: f32) { -// CHECK: linalg.fill(%{{.*}}, %{{.*}}) : f32, memref +// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref) // ----- @@ -84,12 +84,12 @@ func @fill_view3(%arg0: memref, %arg1: f32) { - linalg.fill(%arg1, %arg0) : f32, memref + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) return } // CHECK-LABEL: func @fill_view3( // CHECK: %{{.*}}: memref, %{{.*}}: f32) { -// CHECK: linalg.fill(%{{.*}}, %{{.*}}) : f32, memref +// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref) // ----- @@ -208,9 +208,9 @@ -> (tensor, tensor) { %c0 = arith.constant 0 : index %0 = linalg.init_tensor [] : tensor - %1 = linalg.fill(%arg2, %0) : i32, tensor -> tensor + %1 = linalg.fill ins(%arg2 : i32) outs(%0 : tensor) -> tensor %2 = linalg.init_tensor [] : tensor - %3 = linalg.fill(%arg2, %2) : i32, tensor -> tensor + %3 = linalg.fill ins(%arg2 : i32) outs(%2 : tensor) -> tensor %4:2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} @@ -346,7 +346,7 @@ func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor { %0 = linalg.init_tensor [%arg0, %arg1] : tensor - %1 = linalg.fill(%arg2, %0) : f32, tensor -> tensor + %1 = linalg.fill ins(%arg2 : f32) outs(%0 : tensor) -> tensor return %1 : tensor } -// CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : f32, tensor -> tensor +// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor) -> tensor diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir @@ -7,7 +7,7 @@ %d0 = tensor.dim %arg0, %c0 : tensor %d1 = tensor.dim %arg1, %c1 : tensor %init = linalg.init_tensor [%d0, %d1] : tensor - %fill = linalg.fill(%cst, %init) : f32, tensor -> tensor + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor %result = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%fill : tensor) -> tensor return %result : tensor diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir @@ -17,7 +17,7 @@ %c24 = arith.constant 24 : index %c4 = arith.constant 4 : index %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<24x12xf32>) -> tensor<24x12xf32> // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] = // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] = @@ -31,7 +31,7 @@ // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] // MATMUL-SAME: %[[IV1]], %[[IV2]] // MATMUL-SAME: %[[UB1]], %[[UB2]] - // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // MATMUL: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]] // MATMUL: %{{.*}} = linalg.matmul ins(%[[T1]] %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> return %1 : tensor<24x25xf32> @@ -55,7 +55,7 @@ %c24 = arith.constant 24 : index %c4 = arith.constant 4 : index %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> // Update the iteration argument of the outermost tile loop. // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] @@ -67,7 +67,7 @@ // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] // MATMUL-SAME: %[[IV1]], %[[IV0]] // MATMUL-SAME: %[[TS1]], %[[TS0]] - // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // MATMUL: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]] // MATMUL: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] // Check there is an extract/insert slice pair for the output operand. @@ -184,19 +184,19 @@ %c24 = arith.constant 24 : index %c4 = arith.constant 4 : index %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32> - %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<24x12xf32>) -> tensor<24x12xf32> + %1 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> // Fuse both producers to the appropriate tile loops. // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] // MATMUL: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] // MATMUL-SAME: %[[IV1]], %[[IV0]] - // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // MATMUL: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]] // MATMUL: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] // MATMUL: %[[T2:.*]] = tensor.extract_slice %[[ARG0]] // MATMUL-SAME: %[[IV1]], %[[IV2]] - // MATMUL: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) + // MATMUL: %[[T3:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T2]] // MATMUL: %[[T4:.*]] = tensor.extract_slice %[[ARG5]] // MATMUL: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T4]] %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32> @@ -255,11 +255,11 @@ func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> { %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<10x17xf32>) -> tensor<10x17xf32> // Cannot fuse the output fill since the reduction loop is the outermost loop. - // GENERIC: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]]) - %1 = linalg.fill(%cst, %arg1) : f32, tensor<10xf32> -> tensor<10xf32> + // GENERIC: %[[T0:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[ARG1]] + %1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<10xf32>) -> tensor<10xf32> // GENERIC: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]] // GENERIC: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] @@ -267,7 +267,7 @@ // MATMUL the input fill has been fused. // GENERIC: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] // GENERIC-SAME: %[[IV1]], %[[IV0]] - // GENERIC: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) + // GENERIC: %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]] // GENERIC: %[[T3:.*]] = tensor.extract_slice %[[ARG3]] // GENERIC-SAME: %[[IV1]] // GENERIC: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]] @@ -298,7 +298,7 @@ // GENERIC-DAG: %[[C8:.*]] = arith.constant 8 : index // GENERIC-DAG: %[[C10:.*]] = arith.constant 10 : index %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<10x17xf32>) -> tensor<10x17xf32> // GENERIC: scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]] // GENERIC: scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]] @@ -313,7 +313,7 @@ // GENERIC: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] // GENERIC-SAME: %[[IV1]], %[[SUM]] // GENERIC-SAME: , %[[UB1]] - // GENERIC: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // GENERIC: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]] %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x17xf32>) outs(%arg1 : tensor<10x8xf32>) { ^bb0(%arg2: f32, %arg3: f32): %2 = arith.addf %arg2, %arg3 : f32 diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir @@ -15,8 +15,8 @@ %cst = arith.constant 1.0 : f32 // Do not tile the filter fill since the filter dimensions are not tiled. - // CONV: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) - %0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32> + // CONV: %[[T0:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[ARG0]] + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32> // Fuse all other operations. // CONV: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]] @@ -26,24 +26,24 @@ // CONV-SAME: %[[IV0]], %[[IV1]] // CONV: %[[T2:.*]] = tensor.extract_slice %[[ARG2]] // CONV-SAME: %[[IV0]], %[[IV1]] - // CONV: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) + // CONV: %[[T3:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T2]] // CONV: %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]] - %1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32> + %1 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<10x10xf32>) -> tensor<10x10xf32> %2 = linalg.conv_2d ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32> // CONV: %[[T5:.*]] = tensor.extract_slice %[[ARG3]] // CONV-SAME: %[[IV0]], %[[IV1]] - // CONV: %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]]) + // CONV: %[[T6:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T5]] // CONV: %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]] - %3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32> + %3 = linalg.fill ins(%cst : f32) outs(%arg3 : tensor<9x9xf32>) -> tensor<9x9xf32> %4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32> // Use the argument passed in by iteration argument. // CONV: %[[T8:.*]] = tensor.extract_slice %[[ARG6]] // CONV-SAME: %[[IV0]], %[[IV1]] - // CONV: %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]]) + // CONV: %[[T9:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T8]] // CONV: %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]] - %5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32> + %5 = linalg.fill ins(%cst : f32) outs(%arg4 : tensor<8x8xf32>) -> tensor<8x8xf32> %6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32> return %6 : tensor<8x8xf32> } @@ -61,8 +61,8 @@ %cst = arith.constant 0.000000e+00 : f32 // Do not tile rhs fill of the producer matmul since none of its loop dimension is tiled. - // MATMUL: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) - %0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32> + // MATMUL: %[[T0:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[ARG0]] + %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf32> // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]] // MATMUL: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]] @@ -70,14 +70,14 @@ // Only the outermost loop of the producer matmul is tiled. // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] // MATMUL-SAME: %[[IV0]], 0 - // MATMUL: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) + // MATMUL: %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]] // MATMUL: %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}} %1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> // Use the argument passed in by iteration argument. // MATMUL: %[[T4:.*]] = tensor.extract_slice %[[ARG2]] // MATMUL-SAME: %[[IV0]], %[[IV1]] - // MATMUL: %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]]) + // MATMUL: %[[T5:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T4]] // MATMUL: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]] %2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> return %2 : tensor<8x8xf32> diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -75,7 +75,7 @@ %cst = arith.constant 0.0 : f32 %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> - %fill = linalg.fill(%cst, %init) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> %conv = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} @@ -119,7 +119,7 @@ // CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x3xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>) // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32> -// CHECK-NEXT: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32> +// CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> // CHECK-NEXT: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG0:.+]] = %[[FILL]]) // CHECK-NEXT: %[[OFFSET_H:.+]] = affine.apply #[[MAP0]](%[[IV0]]) @@ -157,7 +157,7 @@ %oc = tensor.dim %elementwise, %c3 : tensor %init = linalg.init_tensor [%n, %oh, %ow, %oc] : tensor - %fill = linalg.fill(%cst, %init) : f32, tensor -> tensor + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor %conv = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} @@ -226,7 +226,7 @@ // CHECK-DAG: %[[ELEM_OC:.+]] = tensor.dim %[[ELEM]], %[[C3]] : tensor // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[ELEM_N]], %[[ELEM_OH]], %[[ELEM_OW]], %[[ELEM_OC]]] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor -> tensor +// CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor) -> tensor // CHECK-DAG: %[[FILTER_H:.+]] = tensor.dim %[[FILTER]], %[[C0]] : tensor // CHECK-DAG: %[[FILTER_W:.+]] = tensor.dim %[[FILTER]], %[[C1]] : tensor @@ -310,7 +310,7 @@ tensor.yield %zero : f32 } : tensor<58x1xf32> to tensor<64x128xf32> - %fill = linalg.fill(%zero, %large_input) : f32, tensor<64x128xf32> -> tensor<64x128xf32> + %fill = linalg.fill ins(%zero : f32) outs(%large_input : tensor<64x128xf32>) -> tensor<64x128xf32> %for0 = scf.for %iv0 = %c0 to %d0 step %c16 iter_args(%arg0 = %fill) -> tensor<64x128xf32> { %for1 = scf.for %iv1 = %c0 to %d1 step %c32 iter_args(%arg1 = %arg0) -> tensor<64x128xf32> { diff --git a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir --- a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir +++ b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir @@ -24,7 +24,7 @@ // CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]] // CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor) { // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TC1]] -// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[SLICE]]) +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[SLICE]] // CHECK: %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor) { // CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor to tensor // CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor to tensor @@ -42,7 +42,7 @@ %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg1, %c1 : tensor %2 = linalg.init_tensor [%0, %1] : tensor - %3 = linalg.fill(%cst, %2) : f32, tensor -> tensor + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor) -> tensor %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"} ins(%arg0, %arg1: tensor, tensor) outs(%3: tensor) diff --git a/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir b/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir --- a/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir +++ b/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir @@ -43,7 +43,7 @@ %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c32 = arith.constant 32 : index - %0 = linalg.fill(%cst, %arg2) : f32, tensor<257x258xf32> -> tensor<257x258xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<257x258xf32>) -> tensor<257x258xf32> %1 = scf.for %arg3 = %c0 to %c257 step %c64 iter_args(%arg4 = %0) -> (tensor<257x258xf32>) { %2 = affine.min #map0(%arg3) %3 = tensor.extract_slice %arg0[%arg3, 0] [%2, 259] [1, 1] : tensor<257x259xf32> to tensor diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -254,35 +254,35 @@ // TILE-234: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs( func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) { - linalg.fill(%arg1, %arg0) : f32, memref<127x99xf32> + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<127x99xf32>) return } // TILE-2-LABEL: func @fill_static // TILE-2: for // TILE-2-NOT: for // TILE-2: memref.subview{{.*}} : memref<127x99xf32> -// TILE-2: linalg.fill{{.*}} : f32, memref +// TILE-2: linalg.fill{{.*}} : memref // TILE-02-LABEL: func @fill_static // TILE-02: for // TILE-02-NOT: for // TILE-02: memref.subview{{.*}} : memref<127x99xf32> -// TILE-02: linalg.fill{{.*}} : f32, memref<127x?xf32, #[[$stride_99_1_layout_map]]> +// TILE-02: linalg.fill{{.*}} : memref<127x?xf32, #[[$stride_99_1_layout_map]]> // TILE-002-LABEL: func @fill_static // TILE-002-NOT: for -// TILE-002: linalg.fill{{.*}} f32, memref<127x99xf32> +// TILE-002: linalg.fill{{.*}} : memref<127x99xf32> // TILE-234-LABEL: func @fill_static // TILE-234: for // TILE-234: for // TILE-234-NOT: for // TILE-234: memref.subview{{.*}} : memref<127x99xf32> -// TILE-234: linalg.fill{{.*}} : f32, memref +// TILE-234: linalg.fill{{.*}} : memref func @fill(%arg0: memref, %arg1: f32) { - linalg.fill(%arg1, %arg0) : f32, memref + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) return } // TILE-2-LABEL: func @fill @@ -318,7 +318,7 @@ linalg.generic #pointwise_2d_trait ins(%arg0, %arg1 : memref, memref) outs(%arg2 : memref) { - ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): + ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): %4 = arith.addf %arg4, %arg5 : f32 linalg.yield %4 : f32 } diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -296,8 +296,8 @@ %cf = arith.constant 1.0 : f32 %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] : memref to memref - linalg.fill(%cf, %3) { __internal_linalg_transform__ = "_promote_views_aligned_"} - : f32, memref + linalg.fill { __internal_linalg_transform__ = "_promote_views_aligned_"} + ins(%cf : f32) outs(%3 : memref) return } // CHECK-LABEL: func @aligned_promote_fill @@ -306,9 +306,9 @@ // CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8> // CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref // CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref -// CHECK: linalg.fill({{.*}}, %[[v0]]) : f32, memref +// CHECK: linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref) // CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref -// CHECK: linalg.fill(%[[cf]], %[[v0]]) : f32, memref +// CHECK: linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref) func @aligned_promote_fill_complex(%arg0: memref, offset: ?, strides: [?, 1]>) { %c2000 = arith.constant 2000 : index @@ -319,8 +319,8 @@ %cc = complex.create %cf, %cf : complex %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] : memref, offset: ?, strides: [?, 1]> to memref, offset: ?, strides: [?, ?]> - linalg.fill(%cc, %3) { __internal_linalg_transform__ = "_promote_views_aligned_"} - : complex, memref, offset: ?, strides: [?, ?]> + linalg.fill { __internal_linalg_transform__ = "_promote_views_aligned_"} + ins(%cc : complex) outs(%3 : memref, offset: ?, strides: [?, ?]>) return } // CHECK-LABEL: func @aligned_promote_fill_complex @@ -329,9 +329,9 @@ // CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<64000000xi8> // CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref> // CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref> to memref, #[[$STRIDED_2D_u_1]]> -// CHECK: linalg.fill({{.*}}, %[[v0]]) : complex, memref> +// CHECK: linalg.fill ins({{.*}} : complex) outs(%[[v0]] : memref>) // CHECK: memref.copy %[[s0]], %[[l0]] : memref, #map{{.*}}> to memref, #map{{.*}}> -// CHECK: linalg.fill(%[[cc]], %[[v0]]) : complex, memref> +// CHECK: linalg.fill ins(%[[cc]] : complex) outs(%[[v0]] : memref>) func @tile_permute_parallel_loop(%arg0: memref, %arg1: memref, diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -191,7 +191,7 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> - linalg.fill(%arg0, %A) : f32, memref<8x16xf32> + linalg.fill ins(%arg0 : f32) outs(%A : memref<8x16xf32>) return } @@ -202,7 +202,7 @@ // CHECK-SAME: (%[[M:.*]]: memref, %[[val:.*]]: f32) // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector, memref - linalg.fill(%arg0, %A) : f32, memref + linalg.fill ins(%arg0 : f32) outs(%A : memref) return } @@ -590,7 +590,7 @@ // CHECK: %[[V4:.*]] = arith.addi %[[DIM3]], %[[C3]] : index // CHECK: %[[V5:.*]] = arith.addi %[[V4]], %[[C2]] : index // CHECK: %[[INIT:.*]] = linalg.init_tensor [6, %[[V1]], %[[V2]], %[[V5]]] : tensor<6x?x?x?xf32> -// CHECK: %[[FILL:.*]] = linalg.fill(%{{.*}}, %[[INIT]]) : f32, tensor<6x?x?x?xf32> -> tensor<6x?x?x?xf32> +// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%[[INIT]] : tensor<6x?x?x?xf32>) -> tensor<6x?x?x?xf32> // CHECK: %[[SRCDIM:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32> // CHECK: %[[RESULT:.*]] = tensor.insert_slice %[[SRC]] into %[[FILL]][2, %[[LOW]], 3, 3] [1, 2, 2, %[[SRCDIM]]] [1, 1, 1, 1] : tensor<1x2x2x?xf32> into tensor<6x?x?x?xf32> // CHECK: return %[[RESULT]] @@ -833,7 +833,7 @@ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant -3.40282e+38 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> - %fill = linalg.fill(%ident, %init) : f32, tensor<4xf32> -> tensor<4xf32> + %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -858,7 +858,7 @@ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %maxf32 = arith.constant 3.40282e+38 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> - %fill = linalg.fill(%maxf32, %init) : f32, tensor<4xf32> -> tensor<4xf32> + %fill = linalg.fill ins(%maxf32 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -881,7 +881,7 @@ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant 1.0 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> - %fill = linalg.fill(%ident, %init) : f32, tensor<4xf32> -> tensor<4xf32> + %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -904,7 +904,7 @@ // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false %init = linalg.init_tensor [4] : tensor<4xi1> - %fill = linalg.fill(%ident, %init) : i1, tensor<4xi1> -> tensor<4xi1> + %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -927,7 +927,7 @@ // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant true %init = linalg.init_tensor [4] : tensor<4xi1> - %fill = linalg.fill(%ident, %init) : i1, tensor<4xi1> -> tensor<4xi1> + %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -950,7 +950,7 @@ // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false %init = linalg.init_tensor [4] : tensor<4xi1> - %fill = linalg.fill(%ident, %init) : i1, tensor<4xi1> -> tensor<4xi1> + %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} @@ -974,7 +974,7 @@ // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32> %c0 = arith.constant 0.0 : f32 %init = linalg.init_tensor [4, 4] : tensor<4x4xf32> - %fill = linalg.fill(%c0, %init) : f32, tensor<4x4xf32> -> tensor<4x4xf32> + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, 0)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -1003,7 +1003,7 @@ // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32> %c0 = arith.constant 0.0 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> - %fill = linalg.fill(%c0, %init) : f32, tensor<4xf32> -> tensor<4xf32> + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, 0)>, affine_map<(d0, d1) -> (d0)>], @@ -1034,7 +1034,7 @@ // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][] // CHECK-SAME: : vector, tensor - %1 = linalg.fill(%f0, %0) : f32, tensor -> tensor + %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor) -> tensor // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> // CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -446,8 +446,8 @@ // %[[V:.*]] = memref.alloca(%[[S]]) : memref // %[[F:.*]] = memref.alloca(%[[S]]) : memref // %[[A:.*]] = memref.alloca(%[[S]]) : memref -// linalg.fill(%{{.*}}, %[[V]]) : f64, memref -// linalg.fill(%{{.*}}, %[[F]]) : i1, memref +// linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref) +// linalg.fill ins(%{{.*}} : i1) outs(%[[F]] : memref) // CHECK: return func @sparse_expansion() { %c = arith.constant 8 : index diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir --- a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir @@ -35,7 +35,7 @@ // CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<13xi32> -// CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<13xi32> +// CHECK-DAG: linalg.fill ins(%[[zeroI32]] : i32) outs(%[[M]] : memref<13xi32>) // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]]) @@ -74,7 +74,7 @@ // CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]]) : memref -// CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref +// CHECK-DAG: linalg.fill ins(%[[zeroI32]] : i32) outs(%[[M]] : memref) // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]]) @@ -117,7 +117,7 @@ // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<2x4xf64> // CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<2x4xf64> +// CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref<2x4xf64>) // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]]) @@ -161,7 +161,7 @@ // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]]) : memref // CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref +// CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref) // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]]) @@ -205,7 +205,7 @@ // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI1]]) : memref<2x?xf64> // CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<2x?xf64> +// CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref<2x?xf64>) // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]]) @@ -249,7 +249,7 @@ // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]], %[[SizeI1]]) : memref // CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref +// CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref) // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]]) @@ -297,7 +297,7 @@ // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<2x3x4xf64> // CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<2x3x4xf64> +// CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref<2x3x4xf64>) // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]]) diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir @@ -52,7 +52,7 @@ // CHECK: %[[VAL_5:.*]] = arith.constant 1 : index // CHECK: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref // CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<32xf32> -// CHECK: linalg.fill(%[[VAL_3]], %[[VAL_7]]) : f32, memref<32xf32> +// CHECK: linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_7]] : memref<32xf32>) // CHECK: scf.for %[[VAL_8:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] { // CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_8]]] : memref // CHECK: %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_1]] : f32 diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir @@ -75,7 +75,7 @@ // LINALG: scf.yield %[[A]], %[[i]], %[[j]] : memref, index, index // LINALG: } else { // slow path, fill tmp alloc and yield a memref_casted version of it - // LINALG: linalg.fill(%cst, %[[alloc]]) : f32, memref<4x8xf32> + // LINALG: linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>) // LINALG: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]]) // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]]) @@ -168,7 +168,7 @@ // LINALG-SAME: memref, index, index // LINALG: } else { // slow path, fill tmp alloc and yield a memref_casted version of it - // LINALG: linalg.fill(%cst, %[[alloc]]) : f32, memref<4x8xf32> + // LINALG: linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>) // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]]) // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]]) // LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1] diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir @@ -68,8 +68,8 @@ %RHS10 = memref.alloc() {alignment = 64} : memref<1x10xf32> %DST10 = memref.alloc() {alignment = 64} : memref<1x10xf32> - linalg.fill(%f1, %LHS10) : f32, memref<1x10xf32> - linalg.fill(%f1, %RHS10) : f32, memref<1x10xf32> + linalg.fill ins(%f1 : f32) outs(%LHS10 : memref<1x10xf32>) + linalg.fill ins(%f1 : f32) outs(%RHS10 : memref<1x10xf32>) %LHS = memref.cast %LHS10 : memref<1x10xf32> to memref %RHS = memref.cast %RHS10 : memref<1x10xf32> to memref diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir @@ -90,8 +90,8 @@ %RHS10 = memref.alloc() {alignment = 64} : memref<1x10xf32> %DST10 = memref.alloc() {alignment = 64} : memref<1x10xf32> - linalg.fill(%f1, %LHS10) : f32, memref<1x10xf32> - linalg.fill(%f1, %RHS10) : f32, memref<1x10xf32> + linalg.fill ins(%f1 : f32) outs(%LHS10 : memref<1x10xf32>) + linalg.fill ins(%f1 : f32) outs(%RHS10 : memref<1x10xf32>) %LHS = memref.cast %LHS10 : memref<1x10xf32> to memref %RHS = memref.cast %RHS10 : memref<1x10xf32> to memref diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir @@ -59,9 +59,9 @@ %B = memref.alloc() : !row_major_B %C = memref.alloc() : !row_major_C - linalg.fill(%v1, %A) : !elem_type_a, !row_major_A - linalg.fill(%v1, %B) : !elem_type_b, !row_major_B - linalg.fill(%v0, %C) : !elem_type_c, !row_major_C + linalg.fill ins(%v1 : !elem_type_a) outs(%A : !row_major_A) + linalg.fill ins(%v1 : !elem_type_b) outs(%B : !row_major_B) + linalg.fill ins(%v0 : !elem_type_c) outs(%C : !row_major_C) %c0 = arith.constant 0: index %c1 = arith.constant 1: index @@ -71,7 +71,7 @@ /// Preheating run: scf.for %arg0 = %c0 to %iters step %c1 { %z = arith.constant 0.0 : !elem_type_c - linalg.fill(%z, %C) : !elem_type_c, !row_major_C + linalg.fill ins(%z : !elem_type_c) outs(%C : !row_major_C) call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> () } %t_start_matmul = call @rtclock() : () -> f64 @@ -81,7 +81,7 @@ // Once linalg on tensors is ready, fusing fill at the register level will // be easy. %z = arith.constant 0.0 : !elem_type_c - linalg.fill(%z, %C) : !elem_type_c, !row_major_C + linalg.fill ins(%z : !elem_type_c) outs(%C : !row_major_C) call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> () } %t_end_matmul = call @rtclock() : () -> f64 @@ -90,7 +90,7 @@ // CHECK: {{^0$}} %C_ref = memref.alloc() : !row_major_C - linalg.fill(%v0, %C_ref) : !elem_type_c, !row_major_C + linalg.fill ins(%v0 : !elem_type_c) outs(%C_ref : !row_major_C) linalg.matmul ins(%A, %B : !row_major_A, !row_major_B) outs(%C_ref: !row_major_C) %act = memref.cast %C : !row_major_C to memref<*xf32> diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir @@ -12,7 +12,7 @@ %x = memref.dim %A, %c0 : memref %y = memref.dim %B, %c1 : memref %C = memref.alloc(%x, %y) : memref - linalg.fill(%f0, %C) : f32, memref + linalg.fill ins(%f0 : f32) outs(%C : memref) linalg.matmul ins(%A, %B: memref, memref) outs(%C: memref) return %C : memref @@ -26,7 +26,7 @@ %x = memref.dim %A, %c1 : memref %n = memref.dim %B, %c1 : memref %C = memref.alloc(%m, %n) : memref - linalg.fill(%f0, %C) : f32, memref + linalg.fill ins(%f0 : f32) outs(%C : memref) scf.for %i = %c0 to %n step %c1 { %b = memref.subview %B[0, %i][%x, 1][1, 1] : memref to memref %c = memref.subview %C[0, %i][%m, 1][1, 1] : memref to memref @@ -46,8 +46,8 @@ %val2 = arith.constant 17.0 : f32 %A = memref.alloc(%m, %x) : memref %B = memref.alloc(%x, %n) : memref - linalg.fill(%val1, %A) : f32, memref - linalg.fill(%val2, %B) : f32, memref + linalg.fill ins(%val1 : f32) outs(%A : memref) + linalg.fill ins(%val2 : f32) outs(%B : memref) memref.store %val1, %B[%c0, %c0] : memref %C1 = call @matmul(%A, %B) : (memref, memref) -> memref %C2 = call @matvec(%A, %B) : (memref, memref) -> memref diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir @@ -14,7 +14,7 @@ %cst = arith.constant 0.000000e+00 : f32 %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index - %0 = linalg.fill(%cst, %arg2) : f32, tensor -> tensor + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor) -> tensor %1 = affine.apply #map0(%c0, %c64)[%c2] %2 = linalg.init_tensor [%1, 2] : tensor %3 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %2) -> (tensor) { @@ -83,9 +83,9 @@ %A = linalg.init_tensor [64] : tensor<64xf32> %B = linalg.init_tensor [64] : tensor<64xf32> %C = linalg.init_tensor [] : tensor - %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32> - %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32> - %CC = linalg.fill(%v0, %C) : f32, tensor -> tensor + %AA = linalg.fill ins(%v1 : f32) outs(%A : tensor<64xf32>) -> tensor<64xf32> + %BB = linalg.fill ins(%v2 : f32) outs(%B : tensor<64xf32>) -> tensor<64xf32> + %CC = linalg.fill ins(%v0 : f32) outs(%C : tensor) -> tensor %res = call @init_and_dot(%AA, %BB, %CC) : (tensor<64xf32>, tensor<64xf32>, tensor) -> tensor diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir @@ -14,7 +14,7 @@ // Creates and returns a 1-D buffer of size %s1 filled with the value %f func @alloc_1d_filled_f32(%s1 : index, %f : f32) -> memref { %buf = memref.alloc(%s1) : memref - linalg.fill(%f, %buf) : f32, memref + linalg.fill ins(%f : f32) outs(%buf : memref) return %buf : memref } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir @@ -14,7 +14,7 @@ // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> memref { %buf = memref.alloc(%s1, %s2, %s3) : memref - linalg.fill(%f, %buf) : f32, memref + linalg.fill ins(%f : f32) outs(%buf : memref) return %buf : memref } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir @@ -14,7 +14,7 @@ // Creates and returns a 2-D buffer of size (%s1, %s2) filled with the value %f func @alloc_2d_filled_f32(%s1 : index, %s2 : index, %f : f32) -> memref { %buf = memref.alloc(%s1, %s2) : memref - linalg.fill(%f, %buf) : f32, memref + linalg.fill ins(%f : f32) outs(%buf : memref) return %buf : memref } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir @@ -14,7 +14,7 @@ // Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f : f32) -> memref { %buf = memref.alloc(%s1, %s2, %s3, %s4) : memref - linalg.fill(%f, %buf) : f32, memref + linalg.fill ins(%f : f32) outs(%buf : memref) return %buf : memref } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir @@ -14,7 +14,7 @@ // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> memref { %buf = memref.alloc(%s1, %s2, %s3) : memref - linalg.fill(%f, %buf) : f32, memref + linalg.fill ins(%f : f32) outs(%buf : memref) return %buf : memref } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir @@ -14,7 +14,7 @@ // Creates and returns 5-D buffer of size (%s1, %s2, %s3, %s4, %s5) filled with the value %f func @alloc_5d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index, %f : f32) -> memref { %buf = memref.alloc(%s1, %s2, %s3, %s4, %s5) : memref - linalg.fill(%f, %buf) : f32, memref + linalg.fill ins(%f : f32) outs(%buf : memref) return %buf : memref } diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -491,7 +491,7 @@ ir_type = _mlir_type_from_taco_type(self.dst_dtype) tensor = linalg.InitTensorOp(self.dst_dims, ir_type).result zero = arith.ConstantOp(ir_type, 0.0) - return linalg.FillOp(output=tensor, value=zero).results[0] + return linalg.fill(zero, outs=[tensor]) # Initialize the sparse tensor. mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims, diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir --- a/mlir/test/mlir-cpu-runner/async.mlir +++ b/mlir/test/mlir-cpu-runner/async.mlir @@ -19,7 +19,7 @@ %c4 = arith.constant 4.0 : f32 %A = memref.alloc() : memref<4xf32> - linalg.fill(%c0, %A) : f32, memref<4xf32> + linalg.fill ins(%c0 : f32) outs(%A : memref<4xf32>) // CHECK: [0, 0, 0, 0] %U = memref.cast %A : memref<4xf32> to memref<*xf32> diff --git a/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir b/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir --- a/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir +++ b/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir @@ -7,14 +7,14 @@ %cf1 = arith.constant 1.00000e+00 : f32 - linalg.fill(%cf1, %A) : f32, memref<16x16xf32> - linalg.fill(%cf1, %B) : f32, memref<16x16xf32> + linalg.fill ins(%cf1 : f32) outs(%A : memref<16x16xf32>) + linalg.fill ins(%cf1 : f32) outs(%B : memref<16x16xf32>) %reps = arith.constant 1 : index %t_start = call @rtclock() : () -> f64 affine.for %arg0 = 0 to 5 { - linalg.fill(%cf1, %C) : f32, memref<16x16xf32> + linalg.fill ins(%cf1 : f32) outs(%C : memref<16x16xf32>) call @sgemm_naive(%A, %B, %C) : (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) -> () } %t_end = call @rtclock() : () -> f64 diff --git a/mlir/test/mlir-cpu-runner/unranked-memref.mlir b/mlir/test/mlir-cpu-runner/unranked-memref.mlir --- a/mlir/test/mlir-cpu-runner/unranked-memref.mlir +++ b/mlir/test/mlir-cpu-runner/unranked-memref.mlir @@ -42,18 +42,18 @@ %f10 = arith.constant 10.00000e+00 : f32 %V = memref.cast %A : memref<10x3xf32, 0> to memref - linalg.fill(%f10, %V) : f32, memref + linalg.fill ins(%f10 : f32) outs(%V : memref) %U = memref.cast %A : memref<10x3xf32, 0> to memref<*xf32> call @print_memref_f32(%U) : (memref<*xf32>) -> () %V2 = memref.cast %U : memref<*xf32> to memref - linalg.fill(%f5, %V2) : f32, memref + linalg.fill ins(%f5 : f32) outs(%V2 : memref) %U2 = memref.cast %V2 : memref to memref<*xf32> call @print_memref_f32(%U2) : (memref<*xf32>) -> () %V3 = memref.cast %V2 : memref to memref<*xf32> %V4 = memref.cast %V3 : memref<*xf32> to memref - linalg.fill(%f2, %V4) : f32, memref + linalg.fill ins(%f2 : f32) outs(%V4 : memref) %U3 = memref.cast %V2 : memref to memref<*xf32> call @print_memref_f32(%U3) : (memref<*xf32>) -> () @@ -79,7 +79,7 @@ func @return_two_var_memref_caller() { %0 = memref.alloca() : memref<4x3xf32> %c0f32 = arith.constant 1.0 : f32 - linalg.fill(%c0f32, %0) : f32, memref<4x3xf32> + linalg.fill ins(%c0f32 : f32) outs(%0 : memref<4x3xf32>) %1:2 = call @return_two_var_memref(%0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) call @print_memref_f32(%1#0) : (memref<*xf32>) -> () call @print_memref_f32(%1#1) : (memref<*xf32>) -> () @@ -94,7 +94,7 @@ func @return_var_memref_caller() { %0 = memref.alloca() : memref<4x3xf32> %c0f32 = arith.constant 1.0 : f32 - linalg.fill(%c0f32, %0) : f32, memref<4x3xf32> + linalg.fill ins(%c0f32 : f32) outs(%0 : memref<4x3xf32>) %1 = call @return_var_memref(%0) : (memref<4x3xf32>) -> memref<*xf32> call @print_memref_f32(%1) : (memref<*xf32>) -> () return diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir --- a/mlir/test/mlir-cpu-runner/utils.mlir +++ b/mlir/test/mlir-cpu-runner/utils.mlir @@ -19,7 +19,7 @@ %f = arith.constant 2.00000e+00 : f32 %A = memref.alloc() : memref<16xf32> %B = memref.cast %A: memref<16xf32> to memref - linalg.fill(%f, %B) : f32, memref + linalg.fill ins(%f : f32) outs(%B : memref) %U = memref.cast %B : memref to memref<*xf32> call @print_memref_f32(%U): (memref<*xf32>) -> () memref.dealloc %A : memref<16xf32> @@ -33,7 +33,7 @@ %f4 = arith.constant 4.00000e+00 : f32 %A = memref.alloc() : memref<3x4x5xf32> %B = memref.cast %A: memref<3x4x5xf32> to memref - linalg.fill(%f, %B) : f32, memref + linalg.fill ins(%f : f32) outs(%B : memref) %c2 = arith.constant 2 : index memref.store %f4, %B[%c2, %c2, %c2]: memref diff --git a/mlir/test/mlir-opt/async.mlir b/mlir/test/mlir-opt/async.mlir --- a/mlir/test/mlir-opt/async.mlir +++ b/mlir/test/mlir-opt/async.mlir @@ -20,7 +20,7 @@ %c4 = arith.constant 4.0 : f32 %A = memref.alloc() : memref<4xf32> - linalg.fill(%c0, %A) : f32, memref<4xf32> + linalg.fill ins(%c0 : f32) outs(%A : memref<4xf32>) %U = memref.cast %A : memref<4xf32> to memref<*xf32> call @print_memref_f32(%U): (memref<*xf32>) -> () diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -65,22 +65,22 @@ # CHECK-LABEL: func @fill_tensor # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32> # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32 - # CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[CST]], %[[OUT]]) : f32, tensor<12x?xf32> -> tensor<12x?xf32> + # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32> # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32> @builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32)) def fill_tensor(out): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result - return linalg.FillOp(output=out, value=zero).result + return linalg.fill(zero, outs=[out]) # CHECK-LABEL: func @fill_buffer # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32> # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32 - # CHECK-NEXT: linalg.fill(%[[CST]], %[[OUT]]) : f32, memref<12x?xf32> + # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>) # CHECK-NEXT: return @builtin.FuncOp.from_py_func(MemRefType.get((12, -1), f32)) def fill_buffer(out): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result - linalg.FillOp(output=out, value=zero) + linalg.fill(zero, outs=[out]) print(module) @@ -179,9 +179,9 @@ def pass_an_op_directly(arg0, arg1): one = arith.ConstantOp(F32Type.get(), 1.0) # CHECK: %[[LHS:.*]] = linalg.fill - lhs = linalg.FillOp(arg0, one) + lhs = linalg.fill(one, outs=[arg0]) # CHECK: %[[RHS:.*]] = linalg.fill - rhs = linalg.FillOp(arg1, one) + rhs = linalg.fill(one, outs=[arg1]) # CHECK: %[[INIT:.*]] = linalg.init_tensor init = linalg.InitTensorOp([4, 8], f32) # CHECK: linalg.matmul diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -29,10 +29,10 @@ %rhs = memref.alloc() : memref<4x8xf32> %O0 = memref.alloc() : memref<4x8xf32> %O1 = memref.alloc() : memref<4x8xf32> - linalg.fill(%v1, %lhs) : f32, memref - linalg.fill(%v2, %rhs) : f32, memref<4x8xf32> - linalg.fill(%v0, %O0) : f32, memref<4x8xf32> - linalg.fill(%v0, %O1) : f32, memref<4x8xf32> + linalg.fill ins(%v1 : f32) outs(%lhs : memref) + linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>) + linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>) + linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>) call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) : (memref, memref<4x8xf32>, memref<4x8xf32>) -> () @@ -60,10 +60,10 @@ %B = memref.alloc() : memref<16x8xf32> %C0 = memref.alloc() : memref<4x8xf32> %C1 = memref.alloc() : memref<4x8xf32> - linalg.fill(%v1, %A) : i8, memref<4x16xi8> - linalg.fill(%v2, %B) : f32, memref<16x8xf32> - linalg.fill(%v0, %C0) : f32, memref<4x8xf32> - linalg.fill(%v0, %C1) : f32, memref<4x8xf32> + linalg.fill ins(%v1 : i8) outs(%A : memref<4x16xi8>) + linalg.fill ins(%v2 : f32) outs(%B : memref<16x8xf32>) + linalg.fill ins(%v0 : f32) outs(%C0 : memref<4x8xf32>) + linalg.fill ins(%v0 : f32) outs(%C1 : memref<4x8xf32>) call @matmul_signed_on_buffers(%A, %B, %C0) : (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> () @@ -137,9 +137,9 @@ %input = memref.alloc() : memref<1x4x16x1xf64> %filter = memref.alloc() : memref<2x2x1xf64> %output = memref.alloc() : memref<1x2x4x1xi32> - linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64> - linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64> - linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32> + linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>) + linalg.fill ins(%v2 : f64) outs(%filter : memref<2x2x1xf64>) + linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>) call @conv_on_buffers(%input, %filter, %output) : (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> () @@ -163,9 +163,9 @@ %input = memref.alloc() : memref<1x4x16x1xf64> %shape = memref.alloc() : memref<2x2xf64> %output = memref.alloc() : memref<1x2x4x1xi32> - linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64> - linalg.fill(%v1, %shape) : f64, memref<2x2xf64> - linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32> + linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>) + linalg.fill ins(%v1 : f64) outs(%shape : memref<2x2xf64>) + linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>) %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -368,15 +368,15 @@ @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32)) def fill_0d_on_buffers(value, out): - linalg.fill_tensor(value, outs=[out]) + linalg.fill(value, outs=[out]) @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) def fill_1d_on_buffers(value, out): - linalg.fill_tensor(value, outs=[out]) + linalg.fill(value, outs=[out]) @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) def fill_2d_on_buffers(value, out): - linalg.fill_tensor(value, outs=[out]) + linalg.fill(value, outs=[out]) execution_engine = ExecutionEngine(transform(module, fill_boiler)) @@ -403,15 +403,15 @@ @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32)) def fill_0d_on_buffers(value, out): - linalg.fill_tensor(value, outs=[out], emit_generic=True) + linalg.fill(value, outs=[out], emit_generic=True) @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) def fill_1d_on_buffers(value, out): - linalg.fill_tensor(value, outs=[out], emit_generic=True) + linalg.fill(value, outs=[out], emit_generic=True) @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) def fill_2d_on_buffers(value, out): - linalg.fill_tensor(value, outs=[out], emit_generic=True) + linalg.fill(value, outs=[out], emit_generic=True) execution_engine = ExecutionEngine(transform(module, fill_boiler))