[mlir][linalg] Detensorize: keep the function entry block signature intact

If the function entry block contains Linalg ops, the dialect conversion run by
the detensorize pass may change that block's signature. That would change the
function type and make the function op invalid. Work around this by splitting
off a dummy entry block that only holds a `cf.br` before the conversion, then
erasing the branch and merging the blocks back afterwards.

Function declarations have no body (and thus no entry block to split), so the
pass now returns early for them instead of calling `front()` on an empty
region.

Also fixes two comment typos.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -60,7 +60,7 @@
   });
 }
 
-/// A conversion patttern for detensoring `linalg.generic` ops.
+/// A conversion pattern for detensoring `linalg.generic` ops.
 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -69,7 +69,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     Block *originalBlock = op->getBlock();
 
-    // Gather some information about the op before inling its region.
+    // Gather some information about the op before inlining its region.
     Block *opEntryBlock = &*op.getRegion().begin();
     YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
 
@@ -476,6 +476,23 @@
     DenseSet<BlockArgument> blockArgsToDetensor;
     FunctionOpInterface funcOp = getOperation();
 
+    // Function declarations have no body, so there is no entry block to split
+    // and nothing to detensorize.
+    if (funcOp.getFunctionBody().empty())
+      return;
+
+    // Make sure the entry block of the function doesn't contain any Linalg ops.
+    // Otherwise, it may lead to the signature of the block being changed by the
+    // dialect conversion below, which would make the function op invalid
+    // because its type shouldn't change.
+    IRRewriter rewriter(funcOp->getContext());
+    Block *entryBlock = &funcOp.getFunctionBody().front();
+    Block *postEntryBlock =
+        rewriter.splitBlock(entryBlock, entryBlock->begin());
+    rewriter.setInsertionPointToStart(entryBlock);
+    auto branch =
+        rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
+
     if (aggressiveMode.getValue()) {
       AggressiveDetensoringModel costModel;
       costModel.compute(funcOp, typeConverter, opsToDetensor,
@@ -553,6 +570,11 @@
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(canonPatterns))))
       signalPassFailure();
+
+    // Get rid of the dummy entry block we created in the beginning to work
+    // around dialect conversion signature rewriting.
+    rewriter.eraseOp(branch);
+    rewriter.mergeBlocks(postEntryBlock, entryBlock);
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
+
+#map = affine_map<() -> ()>
+func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = tensor.empty() : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs(%0 : tensor<f32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<f32>
+  cf.br ^bb1(%1 : tensor<f32>)
+^bb1(%2: tensor<f32>):  // pred: ^bb0
+  return %2 : tensor<f32>
+}
+
+// CHECK-LABEL: @main
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
+// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
+// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
+// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
+// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: return %[[ELEMENTS]] : tensor<f32>
+
+// Function declarations have no body (hence no entry block to split); the
+// pass must leave them untouched.
+// CHECK-LABEL: func.func private @declaration
+func.func private @declaration(tensor<f32>) -> tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -44,8 +44,8 @@
 }
 
 // CHECK-LABEL: func @main
-// CHECK-NEXT: arith.constant 0 : i32
-// CHECK-NEXT: arith.constant 10
+// CHECK-DAG: arith.constant 0 : i32
+// CHECK-DAG: arith.constant 10
 // CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // CHECK-NEXT: ^[[bb1]](%{{.*}}: i32)
 // CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}}