diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -890,6 +890,27 @@ return b.createOperation(state); }] >, + InterfaceMethod< + /*desc=*/[{ + Clone the current operation with the given location, result types + and operands, copying its attributes but leaving its regions empty + (one empty region per region of the current op). Used to abstract + away the optional underlying region creation. This does not change + the balance between input, output_buffer and init_tensors operands. + }], + /*retTy=*/"Operation *", + /*methodName=*/"cloneWithoutRegions", + (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, + "ValueRange":$operands), + [{ + OperationState state( + loc, ConcreteOp::getOperationName(), operands, resultTypes, + $_op->getAttrs()); + for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt) + state.addRegion(); + return b.createOperation(state); + }] + >, StaticInterfaceMethod< /*desc=*/[{ Returns the region builder for constructing the body for linalg.generic. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -79,38 +79,17 @@ mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter, LinalgOp linalgOp, ValueRange inputs, ValueRange outputs) { - if (auto genericOp = mlir::dyn_cast(*linalgOp)) { - // Generate a new linalg operation that works on buffers. - auto newGenericOp = rewriter.create( - genericOp.getLoc(), - /*resultTensorTypes=*/llvm::None, - /*inputs=*/inputs, - /*outputs=*/outputs, genericOp.indexing_maps(), - genericOp.iterator_types(), genericOp.docAttr(), - genericOp.library_callAttr()); - - // Create a new block in the region of the new Generic Op.
- Block *oldBlock = genericOp.getBody(); - Region &newRegion = newGenericOp.region(); - Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(), - oldBlock->getArgumentTypes()); - - // Clone the body of the old block to the new block. - BlockAndValueMapping mapping; - mapping.map(oldBlock->getArguments(), newBlock->getArguments()); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToEnd(newBlock); - for (auto &op : oldBlock->getOperations()) { - Operation *clonedOp = rewriter.clone(op, mapping); - mapping.map(op.getResults(), clonedOp->getResults()); - } - return newGenericOp; - } SmallVector newOperands = inputs; newOperands.append(outputs.begin(), outputs.end()); - return linalgOp.clone(rewriter, linalgOp.getLoc(), - /*resultTypes=*/ArrayRef{}, newOperands); + auto newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(), + /*resultTypes=*/ArrayRef{}, + newOperands); + for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) { + auto &oldRegion = std::get<0>(regions); + auto &newRegion = std::get<1>(regions); + rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); + } + return newOp; } //===----------------------------------------------------------------------===// @@ -344,9 +323,8 @@ BufferizeTypeConverter typeConverter; // Mark all Standard operations legal.
- target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalOp(); diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -316,3 +316,16 @@ // CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32> } + +// ----- + +// CHECK-LABEL: func @bufferize_dot +func @bufferize_dot(%in: tensor<4xf32>, %out: tensor) -> tensor { + %dot = linalg.dot ins(%in, %in : tensor<4xf32>, tensor<4xf32>) + outs(%out : tensor) -> tensor + return %dot : tensor + // CHECK: linalg.dot ins(%{{.*}}, %{{.*}} : memref<4xf32>, memref<4xf32>) + // CHECK-SAME: outs(%[[OUT:.*]] : memref) + // CHECK: %[[OUT_TENSOR:.*]] = memref.tensor_load %[[OUT]] : memref + // CHECK: return %[[OUT_TENSOR]] +}