diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -24,18 +24,14 @@
 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                            ValueRange inputs, Location loc) {
   assert(inputs.size() == 1);
-  if (inputs[0].getType().isa<TensorType>())
+  auto inputType = inputs[0].getType();
+  if (inputType.isa<TensorType>())
     return nullptr;
 
   // A detensored value is converted back by creating a new tensor from its
   // element(s).
-  auto createNewTensorOp =
-      builder.create<tensor::FromElementsOp>(loc, inputs[0]);
-
-  // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
-  // a tensor<dtype> instead.
-  return builder.create<tensor::CollapseShapeOp>(
-      loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
+  return builder.create<tensor::FromElementsOp>(
+      loc, RankedTensorType::get({}, inputType), inputs[0]);
 }
 
 namespace {
@@ -161,39 +157,6 @@
   }
 };
 
-/// Canonicalizes the pattern of the form
-///
-/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-/// %reshaped_tensor = tensor.collapse_shape %tensor []
-///     : tensor<1xi32> into tensor<i32>
-/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
-///
-/// to just %element.
-struct ExtractFromReshapeFromElements
-    : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
-                                PatternRewriter &rewriter) const final {
-    if (!extract.indices().empty())
-      return failure();
-
-    auto tensorReshape =
-        extract.tensor().getDefiningOp<tensor::CollapseShapeOp>();
-    if (tensorReshape == nullptr)
-      return failure();
-
-    auto tensorFromElements =
-        tensorReshape.getOperand()
-            .getDefiningOp<tensor::FromElementsOp>();
-    if (tensorFromElements == nullptr)
-      return failure();
-
-    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
-    return success();
-  }
-};
-
 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
   LinalgDetensorize() = default;
@@ -591,7 +554,7 @@
       signalPassFailure();
 
     RewritePatternSet canonPatterns(context);
-    canonPatterns.add<ExtractFromReshapeFromElements>(context);
+    tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(canonPatterns))))
       signalPassFailure();
diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -19,8 +19,7 @@
 // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
 // CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
 // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
-// CHECK: %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]]
-// CHECK: return %[[reshaped_tensor_res]]
+// CHECK: return %[[new_tensor_res]]
 
 func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
   %0 = linalg.init_tensor [] : tensor<f32>
@@ -60,8 +59,7 @@
 // CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
 // CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
 // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
-// CHECK: %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]]
-// CHECK: return %[[reshaped_tensor_res]]
+// CHECK: return %[[new_tensor_res]]
 
 func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
   %0 = linalg.init_tensor [] : tensor<f32>
@@ -82,8 +80,7 @@
 // CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
 // CHECK: %[[detensored_res2:.*]] = arith.mulf %[[detensored_res]], %[[arg2_val]]
 // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
-// CHECK: %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]]
-// CHECK: return %[[reshaped_tensor_res]]
+// CHECK: return %[[new_tensor_res]]
 
 func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
   %0 = linalg.init_tensor [] : tensor<f32>
@@ -102,5 +99,4 @@
 // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
 // CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
 // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
-// CHECK: %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]]
-// CHECK: return %[[reshaped_tensor_res]]
+// CHECK: return %[[new_tensor_res]]
diff --git a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
@@ -2,17 +2,14 @@
 
 // TODO: Detensoring breaks if %arg0 or %arg1 are passed directly as tensors. Fix that.
 func @if_true_test(%arg0: i1, %arg1: i32) -> tensor<i32> attributes {} {
-  %arg0_t = tensor.from_elements %arg0 : tensor<1xi1>
-  %arg0_t2 = tensor.collapse_shape %arg0_t [] : tensor<1xi1> into tensor<i1>
-
-  %arg1_t = tensor.from_elements %arg1 : tensor<1xi32>
-  %arg1_t2 = tensor.collapse_shape %arg1_t [] : tensor<1xi32> into tensor<i32>
+  %arg0_t = tensor.from_elements %arg0 : tensor<i1>
+  %arg1_t = tensor.from_elements %arg1 : tensor<i32>
 
   %cst = arith.constant dense<10> : tensor<i32>
   %2 = linalg.init_tensor [] : tensor<i8>
   %3 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []}
-    ins(%arg0_t2 : tensor<i1>)
+    ins(%arg0_t : tensor<i1>)
     outs(%2 : tensor<i8>) {
   ^bb0(%arg2: i1, %arg3: i8):  // no predecessors
     %10 = arith.extui %arg2 : i1 to i8
@@ -20,12 +17,12 @@
   } -> tensor<i8>
   %4 = tensor.extract %3[] : tensor<i8>
   %5 = arith.trunci %4 : i8 to i1
-  cond_br %5, ^bb1, ^bb2(%arg1_t2 : tensor<i32>)
+  cond_br %5, ^bb1, ^bb2(%arg1_t : tensor<i32>)
 ^bb1:
   %6 = linalg.init_tensor [] : tensor<i32>
   %7 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []}
-    ins(%arg1_t2, %cst : tensor<i32>, tensor<i32>)
+    ins(%arg1_t, %cst : tensor<i32>, tensor<i32>)
     outs(%6 : tensor<i32>) {
   ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):  // no predecessors
     %10 = arith.addi %arg2, %arg3 : i32
@@ -44,6 +41,5 @@
 // CHECK-NEXT: %[[add_res:.*]] = arith.addi
 // CHECK-NEXT: br ^[[bb2]](%[[add_res]] : i32)
 // CHECK-NEXT: ^[[bb2]]
-// CHECK-NEXT: tensor.from_elements
-// CHECK-NEXT: %[[func_res:.*]] = tensor.collapse_shape
+// CHECK-NEXT: %[[func_res:.*]] = tensor.from_elements
 // CHECK-NEXT: return %[[func_res]]
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -9,17 +9,15 @@
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = arith.constant 0 : i32
-  %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
+  %0 = tensor.from_elements %c0 : tensor<i32>
   %c10 = arith.constant 10 : i32
-  %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
-  br ^bb1(%reshaped0 : tensor<i32>)
+  %1 = tensor.from_elements %c10 : tensor<i32>
+  br ^bb1(%0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
-    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    ins(%2, %1 : tensor<i32>, tensor<i32>)
     outs(%3 : tensor<i1>) {
   ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
     %8 = arith.cmpi slt, %arg0, %arg1 : i32
@@ -54,8 +52,7 @@
 // CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
 // CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32)
 // CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
 // CHECK-NEXT: return %{{.*}}
 // CHECK-NEXT: }
@@ -73,17 +70,15 @@
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = arith.constant 0 : i32
-  %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
+  %0 = tensor.from_elements %c0 : tensor<i32>
   %c10 = arith.constant 10 : i32
-  %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
-  br ^bb1(%reshaped0 : tensor<i32>)
+  %1 = tensor.from_elements %c10 : tensor<i32>
+  br ^bb1(%0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
-    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    ins(%2, %1 : tensor<i32>, tensor<i32>)
     outs(%3 : tensor<i1>) {
   ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
     %8 = arith.cmpi slt, %arg0, %arg1 : i32
@@ -123,8 +118,7 @@
 // CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
 // CHECK-NEXT: br ^[[bb4:.*]](%{{.*}} : i32)
 // CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
 // CHECK-NEXT: return %{{.*}}
 // CHECK-NEXT: }
@@ -139,17 +133,15 @@
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = arith.constant 0 : i32
-  %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
+  %0 = tensor.from_elements %c0 : tensor<i32>
   %c10 = arith.constant 10 : i32
-  %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
-  br ^bb1(%reshaped0 : tensor<i32>)
+  %1 = tensor.from_elements %c10 : tensor<i32>
+  br ^bb1(%0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
-    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    ins(%2, %1 : tensor<i32>, tensor<i32>)
     outs(%3 : tensor<i1>) {
   ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
     %8 = arith.cmpi slt, %arg0, %arg1 : i32
@@ -163,11 +155,10 @@
   cond_br %5, ^bb2(%2 : tensor<i32>), ^bb2(%2 : tensor<i32>)
 
 ^bb2(%6: tensor<i32>):  // pred: ^bb1
-  %12 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped12 = tensor.collapse_shape %12 [] : tensor<1xi32> into tensor<i32>
+  %12 = tensor.from_elements %c10 : tensor<i32>
   %7 = linalg.init_tensor [] : tensor<i32>
   %8 = linalg.generic #attrs
-    ins(%6, %reshaped12 : tensor<i32>, tensor<i32>)
+    ins(%6, %12 : tensor<i32>, tensor<i32>)
     outs(%7 : tensor<i32>) {
   ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
     %9 = arith.addi %arg0, %arg1 : i32
@@ -190,7 +181,6 @@
 // CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
 // CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32)
 // CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
 // CHECK-NEXT: return %{{.*}}
 // CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -11,11 +11,10 @@
 func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
   %c10 = arith.constant 10 : i32
-  %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
+  %1 = tensor.from_elements %c10 : tensor<i32>
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
-    ins(%farg0, %reshaped1 : tensor<i32>, tensor<i32>)
+    ins(%farg0, %1 : tensor<i32>, tensor<i32>)
     outs(%3 : tensor<i1>) {
   ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):
     %8 = arith.cmpi slt, %arg0, %arg1 : i32
@@ -30,7 +29,6 @@
 // DET-ALL-NEXT: tensor.extract %{{.*}}[]
 // DET-ALL-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
 // DET-ALL-NEXT: tensor.from_elements %{{.*}}
-// DET-ALL-NEXT: tensor.collapse_shape %{{.*}}
 // DET-ALL-NEXT: return %{{.*}} : tensor<i1>
 // DET-ALL-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -52,7 +52,6 @@
 // DET-ALL: br ^[[bb1]](%{{.*}} : i32)
 // DET-ALL: ^[[bb3]](%{{.*}}: i32)
 // DET-ALL: tensor.from_elements {{.*}}
-// DET-ALL: tensor.collapse_shape {{.*}}
 // DET-ALL: return %{{.*}} : tensor<i32>
 
 // Test detensoring only ops involed in control-flow.
@@ -68,6 +67,5 @@
 // DET-CF: arith.addi {{.*}}
 // DET-CF: br ^[[bb1]](%{{.*}} : i32)
 // DET-CF: ^[[bb3]](%{{.*}}: i32)
-// DET-CF: tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-CF: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-CF: tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF: return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -76,8 +76,7 @@
 // DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32
 // DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
 // DET-ALL: ^[[bb2]](%{{.*}}: i32)
-// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-ALL: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL: tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL: linalg.init_tensor [10] : tensor<10xi32>
 // DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
 // DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32):
@@ -85,8 +84,7 @@
 // DET-ALL: } -> tensor<10xi32>
 // DET-ALL: br ^[[bb1]](%{{.*}} : tensor<10xi32>)
 // DET-ALL: ^[[bb3]](%{{.*}}: i32)
-// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-ALL: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL: tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL: return %{{.*}} : tensor<i32>
 // DET-ALL: }