diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -642,13 +642,14 @@ LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    // see if tensor input of tensor.extract op is result of linalg.fill op
+    // See if tensor input of tensor.extract op is the result of a linalg.fill
+    // op.
     auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
     if (!fillOp)
       return failure();
 
-    // get scalar input operand of linalg.fill
+    // Get scalar input operand of linalg.fill op.
     Value extractedScalar = fillOp.getInputs()[0];
 
-    // replace tensor.extract op with op that simply produces the scalar
+    // Replace tensor.extract op with scalar value used to fill the tensor.
     rewriter.replaceOp(extractOp, extractedScalar);
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -337,20 +337,19 @@
 // -----
 
-// CHECK-LABEL: func @fold_fill_extract()
-func.func @fold_fill_extract() -> i1 {
-  // CHECK: %[[SCALAR:.+]] = arith.constant
-  %true = arith.constant true
+// CHECK-LABEL: func @fold_fill_extract
+//  CHECK-SAME:   %[[ARG0:.+]]: i1
+func.func @fold_fill_extract(%arg0 : i1) -> i1 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
 
-  // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[SCALAR]]
   %empty_dynamic = tensor.empty(%c1) : tensor<1x2x3x?xi1>
-  %filled = linalg.fill ins(%true : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1>
+  %filled = linalg.fill ins(%arg0 : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1>
 
-  // CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[FILL]]
   %extracted = tensor.extract %filled[%c0, %c0, %c0, %c0] : tensor<1x2x3x?xi1>
 
-  // CHECK: return %[[EXTRACTED]]
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.extract
+  // CHECK: return %[[ARG0]]
   return %extracted : i1
 }
 