diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -284,6 +284,21 @@
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
 
+  // Only ShapedType operands are forwarded as the tiled_loop `ins`/`outs`.
+  // Scalar operands (e.g. the fill value of linalg.fill) are dropped here.
+  auto dropNonShapedValues =
+      [](ArrayRef<OpOperand *> operands) -> SmallVector<Value> {
+    SmallVector<Value> filteredOperands;
+    for (OpOperand *operand : operands) {
+      Type type = operand->get().getType();
+      if (type.isa<ShapedType>())
+        filteredOperands.push_back(operand->get());
+    }
+    return filteredOperands;
+  };
+  auto inputOperands = dropNonShapedValues(linalgOp.getInputOperands());
+  auto outputOperands = dropNonShapedValues(linalgOp.getOutputOperands());
+
   auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                               ValueRange ivs, ValueRange inputs,
                               ValueRange outputs) {
@@ -292,9 +307,6 @@
         bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors);
     nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
   };
-
-  SmallVector<Value> inputOperands = linalgOp.getInputOperands();
-  SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
   auto tiledLoop =
       b.create<TiledLoopOp>(loc, lbs, ubs, steps, inputOperands, outputOperands,
                             b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -130,3 +130,18 @@
 // TLOOP-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]])
 // TLOOP-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]])
 // TLOOP-SAME: distribution["block_x", "block_y", "none"] {
+
+
+// The scalar fill value must not be forwarded as a tiled_loop `ins` operand.
+func @fill(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %c0 = constant 0.0 : f32
+  %0 = linalg.fill(%c0, %arg0) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @fill
+
+// TLOOP-LABEL: func @fill
+// TLOOP-NOT: ins
+// TLOOP: tensor.extract_slice
+// TLOOP-NEXT: linalg.fill
+// TLOOP-NEXT: tensor.insert_slice