diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -644,16 +644,18 @@
                                 PatternRewriter &rewriter) const override {
     // see if tensor input of tensor.extract op is result of linalg.fill op
     auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
-    if (!fillOp) {
+    if (!fillOp)
       return failure();
-    }
 
     // get scalar input operand of linalg.fill
     Value extractedScalar = fillOp.getInputs()[0];
+    // linalg.fill casts its value operand to the output element type, so the
+    // scalar can only be forwarded as-is when no cast is involved.
+    if (extractedScalar.getType() != extractOp.getType())
+      return failure();
 
-    // replace tensor.extract op with op that simply produces the scalar
-    rewriter.replaceOpWithNewOp(
-        extractOp, extractedScalar.getType(), extractedScalar);
+    // Replace tensor.extract op with the scalar value used to fill the tensor.
+    rewriter.replaceOp(extractOp, extractedScalar);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -335,6 +335,38 @@
   return %1 : tensor
 }
 
+// -----
+// CHECK-LABEL: func @fold_fill_extract
+//  CHECK-SAME:   %[[ARG0:.+]]: i1
+func.func @fold_fill_extract(%arg0 : i1) -> i1 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %empty_dynamic = tensor.empty(%c1) : tensor<1x2x3x?xi1>
+  %filled = linalg.fill ins(%arg0 : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1>
+
+  %extracted = tensor.extract %filled[%c0, %c0, %c0, %c0] : tensor<1x2x3x?xi1>
+
+  // The extract folds to the fill value, which leaves the fill dead.
+  //   CHECK-NOT:   linalg.fill
+  //   CHECK-NOT:   tensor.extract
+  //       CHECK:   return %[[ARG0]]
+  return %extracted : i1
+}
+
+// -----
+// CHECK-LABEL: func @no_fold_fill_extract_with_cast
+func.func @no_fold_fill_extract_with_cast(%arg0 : f32) -> f16 {
+  %c0 = arith.constant 0 : index
+  %empty = tensor.empty() : tensor<2xf16>
+  // linalg.fill casts f32 -> f16, so %arg0 cannot be forwarded directly.
+  //       CHECK:   %[[FILL:.+]] = linalg.fill
+  //       CHECK:   tensor.extract %[[FILL]]
+  %filled = linalg.fill ins(%arg0 : f32) outs(%empty : tensor<2xf16>) -> tensor<2xf16>
+  %extracted = tensor.extract %filled[%c0] : tensor<2xf16>
+  return %extracted : f16
+}
+
 // -----
 // CHECK: func @fold_self_copy