diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -328,7 +328,24 @@ /*defaultImplementation=*/[{ return success(); }] - > + >, + InterfaceMethod< + /*desc=*/[{ + Return the bufferized type of the given tensor block argument. The + block argument is guaranteed to belong to a block of this op. + }], + /*retType=*/"BaseMemRefType", + /*methodName=*/"getBufferType", + /*args=*/(ins "BlockArgument":$bbArg, + "const BufferizationOptions &":$options), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(bbArg.getOwner()->getParentOp() == $_op && + "bbArg must belong to this op"); + auto tensorType = bbArg.getType().cast(); + return bufferization::getMemRefType(tensorType, options); + }] + >, ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -482,8 +482,10 @@ Value bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options) { +#ifndef NDEBUG auto tensorType = value.getType().dyn_cast(); assert(tensorType && "unexpected non-tensor type"); +#endif // NDEBUG // Replace "%t = to_tensor %m" with %m. if (auto toTensorOp = value.getDefiningOp()) @@ -492,7 +494,7 @@ // Insert to_memref op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, value); - Type memrefType = getMemRefType(tensorType, options); + Type memrefType = getBufferType(value, options); ensureToMemrefOpIsValid(value, memrefType); return rewriter.create(value.getLoc(), memrefType, value); @@ -507,6 +509,11 @@ if (auto toTensorOp = value.getDefiningOp()) return toTensorOp.getMemref().getType().cast(); + if (auto bbArg = value.dyn_cast()) + if (auto bufferizableOp = + options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp())) + return bufferizableOp.getBufferType(bbArg, options); + return getMemRefType(tensorType, options); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -393,9 +393,16 @@ // Otherwise, we have to use a memref type with a fully dynamic layout map to // avoid copies. We are currently missing patterns for layout maps to // canonicalize away (or canonicalize to more precise layouts). + // + // FuncOps must be bufferized before their bodies, so add them to the worklist + // first. SmallVector worklist; - op->walk([&](Operation *op) { - if (hasTensorSemantics(op)) + op->walk([&](func::FuncOp funcOp) { + if (hasTensorSemantics(funcOp)) + worklist.push_back(funcOp); + }); + op->walk([&](Operation *op) { + if (hasTensorSemantics(op) && !isa(op)) worklist.push_back(op); }); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -725,7 +725,7 @@ // The result types of a WhileOp are the same as the "after" bbArg types. SmallVector argsTypesAfter = llvm::to_vector( llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { - return getBufferType(bbArg, options).cast(); + return bufferization::getBufferType(bbArg, options).cast(); })); // Construct a new scf.while op with memref instead of tensor values. @@ -1107,7 +1107,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &b, const BufferizationOptions &options) const { // Will be bufferized as part of ForeachThreadOp. - return failure(); + return success(); } // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -154,7 +154,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { // Op is bufferized as part of AssumingOp. - return failure(); + return success(); } }; diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -313,10 +313,10 @@ // CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}}) // CHECK: memref.copy %[[iter2]], %[[alloc2]] // CHECK: memref.dealloc %[[iter2]] -// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]] // CHECK: %[[alloc1:.*]] = memref.alloc(%{{.*}}) // CHECK: memref.copy %[[iter1]], %[[alloc1]] // CHECK: memref.dealloc %[[iter1]] +// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]] // CHECK: %[[casted1:.*]] = memref.cast %[[alloc1]] // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]] // CHECK: memref.dealloc %[[alloc1]] @@ -384,10 +384,10 @@ // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1> // CHECK: memref.copy %[[w1]], %[[a1]] // CHECK: memref.dealloc %[[w1]] - // CHECK: %[[casted1:.*]] = memref.cast %[[a1]] // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1> // CHECK: memref.copy %[[w0]], %[[a0]] // CHECK: memref.dealloc %[[w0]] + // CHECK: %[[casted1:.*]] = memref.cast %[[a1]] // CHECK: %[[casted0:.*]] = memref.cast %[[a0]] // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]] // CHECK: memref.dealloc %[[a0]] @@ -437,10 +437,10 @@ // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1> // CHECK: memref.copy %[[w1]], %[[a1]] // CHECK: memref.dealloc %[[w1]] - // CHECK: %[[casted1:.*]] = memref.cast %[[a1]] // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1> // CHECK: memref.copy %[[w0]], %[[a0]] // CHECK: memref.dealloc %[[w0]] + // CHECK: %[[casted1:.*]] = memref.cast %[[a1]] // CHECK: %[[casted0:.*]] = memref.cast %[[a0]] // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]] // CHECK: memref.dealloc %[[a0]] @@ -457,9 +457,9 @@ // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1> // CHECK: memref.copy %[[b1]], %[[a3]] // CHECK: memref.dealloc %[[b1]] - // CHECK: %[[casted3:.*]] = memref.cast %[[a3]] // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1> // CHECK: memref.copy %[[b0]], %[[a2]] + // CHECK: %[[casted3:.*]] = memref.cast %[[a3]] // CHECK: %[[casted2:.*]] = memref.cast %[[a2]] // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]] // CHECK: memref.dealloc %[[a2]]