diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -268,6 +268,9 @@ llvm::EquivalenceClasses equivalentInfo; }; +/// Return `true` if the given value is a BlockArgument of a FuncOp. +bool isFunctionArgument(Value value); + /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. SmallVector getAliasingOpOperand(OpResult result); @@ -342,18 +345,18 @@ DialectBufferizationState(const DialectBufferizationState &) = delete; }; -/// BufferizationState keeps track of memory buffers and provides a variety of -/// helper functions for dealing with them. In particular, +/// BufferizationState provides a variety of helper functions for dealing with +/// tensor values and memref buffers. In particular, /// `BufferizableOpInterface::bufferize` implementation should utilize the /// following helper functions. /// /// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops /// that allocate and/or deallocate memref buffers. -/// * `mapBuffer` maps a tensor value to a memref buffer during bufferization. -/// * `lookupBuffer` returns the mapped memref buffer of a given tensor value. +/// * `lookupBuffer` returns the memref buffer of a given tensor value. /// * `getResultBuffer` returns the memref buffer for a given tensor OpResult. /// Based on inplace bufferization decisions of the analysis, it may either /// directly return a mapped buffer or allocate a new brand new buffer. +/// * `replaceOp` replaces an op with new values. class BufferizationState { public: BufferizationState(Operation *op, const BufferizationOptions &options) @@ -378,16 +381,19 @@ /// Creates a memcpy between two given buffers. void createMemCpy(OpBuilder &b, Location loc, Value from, Value to); - /// Replace an op with replacement values. The op is deleted. + /// Replace an op with replacement values. The op is deleted. Tensor OpResults + /// must be replaced with memref values. void replaceOp(Operation *op, ValueRange values); - /// Map tensor values to memref buffers. - // TODO: Deprecated. Remove all uses of this op. Use `replaceOp` instead. - void mapBuffer(ValueRange tensors, ValueRange buffers); - - /// Map a tensor value to a memref buffer. - // TODO: Deprecated. Remove all uses of this op. Use `replaceOp` instead. - void mapBuffer(Value tensor, Value buffer); + /// Replace an op with a new op. Tensor OpResults must be replaced with memref + /// values. + template + OpTy replaceOpWithNewOp(OpBuilder &b, Operation *op, Args &&...args) { + Operation *newOp = + b.create(op->getLoc(), std::forward(args)...); + replaceOp(op, newOp->getResults()); + return cast(newOp); + } /// Lookup the memref buffer that is associated to the given tensor value. /// Asserts if no buffer is associated. @@ -396,23 +402,11 @@ /// Return `true` if the given OpResult has been decided to bufferize inplace. bool isInPlace(OpResult opResult) const; - /// Return `true` if the given value is mapped. - // TODO: Deprecated. Remove all uses of this op. - bool isMapped(Value value) const; - /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. Value getResultBuffer(OpResult result); - /// Mark `op` as obsolete, so that it is deleted after bufferization. - // TODO: Deprecated. Remove all uses of this op. - void markOpObsolete(Operation *op); - - /// Erase all ops that were marked obsolete. - // TODO: Deprecated. Remove all uses of this op. - void eraseObsoleteOps(); - /// Return dialect-specific bufferization state. template StateT &getDialectState(StringRef name) { // Create state if it does not exist yet. @@ -441,12 +435,6 @@ /// functions and `runComprehensiveBufferize` may access this object. BufferizationAliasInfo aliasInfo; - /// The mapping of tensors to buffers. - BlockAndValueMapping mapping; - - /// Obsolete ops that should be deleted after bufferization. - SmallVector obsoleteOps; - /// Dialect-specific bufferization state. DenseMap> dialectState; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp @@ -35,10 +35,8 @@ GlobalCreator globalCreator(moduleOp); auto globalMemref = globalCreator.getGlobalFor(constantOp); - Value memref = b.create( - constantOp.getLoc(), globalMemref.type(), globalMemref.getName()); - state.mapBuffer(constantOp, memref); - + state.replaceOpWithNewOp(b, op, globalMemref.type(), + globalMemref.getName()); return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -498,19 +498,6 @@ if (!state.getOptions().allowUnknownOps) return op->emitError() << "unsupported op with tensors"; - // Replace all OpOperands with "to-tensor casted" bufferized values. - for (OpOperand &operand : op->getOpOperands()) { - if (operand.get().getType().isa() && - state.isMapped(operand.get())) { - assert(state.getOptions().allowUnknownOps && - "unsupported op error should have been emitted earlier"); - b.setInsertionPoint(op); - Value toTensorOp = b.create( - op->getLoc(), state.lookupBuffer(operand.get())); - operand.set(toTensorOp); - } - } - // Bufferize all regions. for (Region ®ion : op->getRegions()) if (failed(bufferize(®ion, state))) @@ -654,38 +641,13 @@ // Bufferization-specific BlockAndValueMapping support with debugging. //===----------------------------------------------------------------------===// -/// Wrapper for better debugging. -void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer( - ValueRange tensors, ValueRange buffers) { - assert(!tensors.empty() && "unexpected empty tensors"); -#ifndef NDEBUG - for (Value tensor : tensors) { - assert(tensor && "unexpected empty tensor"); - assert(tensor.getType().isa() && "unexpected non-tensor type"); - } - for (Value buffer : buffers) { - assert(buffer && "unexpected empty buffer"); - assert((buffer.getType().isa() || - buffer.getType().isa()) && - "expected that tensor is mapped to memref"); - } -#endif // NDEBUG - return mapping.map(tensors, buffers); -} - -/// Wrapper for better debugging. -void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer( - Value tensor, Value buffer) { - assert(tensor && "unexpected empty tensor"); - assert(tensor.getType().isa() && "unexpected non-tensor type"); - assert(buffer && "unexpected empty buffer"); - assert((buffer.getType().isa() || - buffer.getType().isa()) && - "expected that tensor is mapped to memref"); - return mapping.map(tensor, buffer); +bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) { + auto bbArg = value.dyn_cast(); + if (!bbArg) + return false; + return isa(bbArg.getOwner()->getParentOp()); } -/// Wrapper for better debugging. Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer( Value tensor) { assert(tensor.getType().isa() && "unexpected non-tensor type"); @@ -694,37 +656,29 @@ if (auto toTensorOp = tensor.getDefiningOp()) return toTensorOp.memref(); - Value buffer = mapping.lookupOrNull(tensor); - if (!buffer) { - if (options.allowUnknownOps) { - // `tensor` was not bufferized yet. This should never happen with - // bufferizable ops. - assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped"); - // Insert to_memref op. - OpBuilder b(tensor.getContext()); - setInsertionPointAfter(b, tensor); - return b.create( - tensor.getLoc(), - getDynamicMemRefType(tensor.getType().cast()), - tensor); + if (!isFunctionArgument(tensor)) { + if (static_cast(options.dynCastBufferizableOp(tensor))) { + // Dump tensor for easier debugging. + tensor.dump(); + llvm_unreachable("op is known, but has not been bufferized yet"); + return Value(); + } + if (!options.allowUnknownOps) { + // Dump tensor for easier debugging. + tensor.dump(); + // Note: An assertion should already have failed earlier. + llvm_unreachable("unknown ops are not allowed"); + return Value(); } - - // Dump tensor for easier debugging. - tensor.dump(); - llvm_unreachable("tensor is not mapped"); - return Value(); } - assert((buffer.getType().isa() || - buffer.getType().isa()) && - "expected that tensor is mapped to memref"); - return buffer; -} - -bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped( - Value value) const { - assert(value.getType().isa() && "unexpected non-tensor type"); - return mapping.contains(value); + // Insert to_memref op. + OpBuilder &b = getBuilder(); + OpBuilder::InsertionGuard g(b); + setInsertionPointAfter(b, tensor); + return b.create( + tensor.getLoc(), + getDynamicMemRefType(tensor.getType().cast()), tensor); } bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace( @@ -732,18 +686,6 @@ return aliasInfo.isInPlace(opResult); } -void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete( - Operation *op) { - obsoleteOps.push_back(op); -} - -void mlir::linalg::comprehensive_bufferize::BufferizationState:: - eraseObsoleteOps() { - for (Operation *op : obsoleteOps) - op->erase(); - obsoleteOps.clear(); -} - MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType( ShapedType shapedType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp @@ -63,21 +63,9 @@ // If a ToMemrefOp's tensor operand has not been bufferized yet, the op // remains unchanged. All IR up to this ToMemrefOp has already been // bufferized, unless there were unknown ops that could be bufferized. - if (!state.isMapped(toMemrefOp.tensor())) { - assert(state.getOptions().allowUnknownOps && - "expected that tensor is mapped"); - return success(); - } - - // If a ToMemrefOp's tensor operand has been bufferized, the op can be - // removed. - Value memref = state.lookupBuffer(toMemrefOp.tensor()); - // Do not replace a ToMemrefOp with itself. E.g., when bufferizing a - // function body, ToMemrefOps were inserted before starting bufferization of - // the function body. Such ToMemrefOps are replaced in a separate step after - // the function body has been bufferized. - if (toMemrefOp.getResult() != memref) - toMemrefOp.replaceAllUsesWith(memref); + assert((isFunctionArgument(toMemrefOp.tensor()) || + state.getOptions().allowUnknownOps) && + "expected that tensor is mapped"); return success(); } @@ -98,8 +86,6 @@ bufferization::ToTensorOp> { LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { - auto tensorLoadOp = cast(op); - state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref()); return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -699,8 +699,5 @@ if (failed(bufferize(op, state))) return failure(); - // Erase all obsolete ops. - state.eraseObsoleteOps(); - return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -56,7 +56,6 @@ if (!resultBuffer) return failure(); newOutputBuffers.push_back(resultBuffer); - state.mapBuffer(opResult, resultBuffer); } // Clone the newly bufferized op. @@ -68,7 +67,9 @@ auto bufferizedOp = cast( op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); - // The original op will be DCE'd away later. + // Replace the results of the old op with the new output buffers. + state.replaceOp(op, newOutputBuffers); + return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state); } @@ -194,7 +195,7 @@ Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(), initTensorOp.result()); - state.mapBuffer(initTensorOp.result(), alloc); + state.replaceOp(op, alloc); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -551,9 +551,6 @@ .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()]; Value oldRes = callOp->getResult(returnOperand.getOperandNumber()); Value buffer = state.lookupBuffer(callOp->getOperand(idx)); - // Add CallOp operand/result equivalence: this is interprocedural - // info. - state.mapBuffer(oldRes, buffer); // Add a ToTensorOp to kill all uses of the CallOp return. // Replace all uses of the CallOp results so we can erase the CallOp. // This ToTensorOp must fold/DCE away or bufferization should be @@ -561,8 +558,6 @@ Value toTensorOp = b.create(callOp.getLoc(), buffer); oldRes.replaceAllUsesWith(toTensorOp); - // Add new op equivalence info. - state.mapBuffer(toTensorOp, buffer); continue; } @@ -603,8 +598,6 @@ if (buffer.getType() != memRefType) { Value castBuffer = b.create(callOp.getLoc(), memRefType, buffer); - // Add new op equivalence info. - state.mapBuffer(tensorOperand, castBuffer); buffer = castBuffer; } newOperands.push_back(buffer); @@ -616,7 +609,7 @@ newCallOp->setAttrs(callOp->getAttrs()); // 5. Delete the op at the end of bufferization. - state.markOpObsolete(callOp); + callOp->erase(); return success(); } @@ -651,7 +644,6 @@ Value returnTensor = b.create( returnOp.getLoc(), v); operand.set(returnTensor); - state.mapBuffer(returnTensor, v); } return success(); } @@ -662,23 +654,6 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto funcOp = cast(op); - b.setInsertionPointToStart(&funcOp.body().front()); - - // Create BufferCastOps for function args. - for (auto bbArg : funcOp.getArguments()) { - auto tensorType = bbArg.getType().dyn_cast(); - if (!tensorType) - continue; - auto rankedTensorType = tensorType.dyn_cast(); - // Cast the tensor to the most dynamic buffer possible. Further - // canonicalizations will clean up. - Type memRefType = rankedTensorType - ? getDynamicMemRefType(rankedTensorType) - : getContiguousOrUnrankedMemRefType(tensorType); - Value bufferCast = b.create(funcOp.getLoc(), - memRefType, bbArg); - state.mapBuffer(bbArg, bufferCast); - } // Bufferize function body. return comprehensive_bufferize::bufferize(&funcOp.body(), state); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -78,9 +78,7 @@ : MemRefLayoutAttrInterface(); Type memRefType = getContiguousOrUnrankedMemRefType( castOp.getResult().getType(), layout, memorySpace); - Value res = - b.create(castOp.getLoc(), memRefType, resultBuffer); - state.mapBuffer(castOp.getResult(), res); + state.replaceOpWithNewOp(b, op, memRefType, resultBuffer); return success(); } }; @@ -103,11 +101,10 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto dimOp = cast(op); - if (dimOp.source().getType().isa()) { - Value v = state.lookupBuffer(dimOp.source()); - dimOp.result().replaceAllUsesWith( - b.create(dimOp.getLoc(), v, dimOp.index())); - } + if (!dimOp.source().getType().isa()) + return dimOp.emitError("unranked tensor not supported"); + Value v = state.lookupBuffer(dimOp.source()); + state.replaceOpWithNewOp(b, op, v, dimOp.index()); return success(); } }; @@ -168,7 +165,7 @@ subView = alloc; } - state.mapBuffer(extractSliceOp.result(), subView); + state.replaceOp(op, subView); return success(); } }; @@ -191,10 +188,9 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto extractOp = cast(op); - Location loc = extractOp.getLoc(); Value srcMemref = state.lookupBuffer(extractOp.tensor()); - Value l = b.create(loc, srcMemref, extractOp.indices()); - extractOp.replaceAllUsesWith(l); + state.replaceOpWithNewOp(b, op, srcMemref, + extractOp.indices()); return success(); } }; @@ -228,7 +224,7 @@ Value destMemref = state.getResultBuffer(insertOp->getOpResult(0)); b.create(loc, insertOp.scalar(), destMemref, insertOp.indices()); - state.mapBuffer(insertOp, destMemref); + state.replaceOp(op, destMemref); return success(); } @@ -423,7 +419,7 @@ state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView); } - state.mapBuffer(insertSliceOp.result(), dstMemref); + state.replaceOp(op, dstMemref); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -38,13 +38,17 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { - auto transferReadOp = cast(op); - assert(transferReadOp.getShapedType().isa() && + auto readOp = cast(op); + assert(readOp.getShapedType().isa() && "only tensor types expected"); // TransferReadOp always reads from the bufferized op.source(). - Value v = state.lookupBuffer(transferReadOp.source()); - transferReadOp.sourceMutable().assign(v); + Value buffer = state.lookupBuffer(readOp.source()); + Value read = b.create( + readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(), + readOp.permutation_map(), readOp.padding(), readOp.mask(), + readOp.in_boundsAttr()); + state.replaceOp(op, read); return success(); } }; @@ -90,7 +94,7 @@ b.create( writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); - state.mapBuffer(op->getResult(0), resultBuffer); + state.replaceOp(op, resultBuffer); return success(); } diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir @@ -30,11 +30,11 @@ // CHECK: %[[dim:.*]] = tensor.dim %[[A]] // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]] // CHECK: memref.copy %[[A_memref]], %[[casted]] // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor - // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]] // CHECK: return %[[res_tensor]] return %0 : tensor } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir @@ -52,10 +52,10 @@ -> (vector<5xf32>, vector<5xf32>) { %idx = arith.constant 0 : index %cst = arith.constant 0.0 : f32 + // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]] // CHECK: %[[v1:.*]] = vector.transfer_read %[[m1]] %1 = vector.transfer_read %t1[%idx], %cst : tensor, vector<5xf32> - // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]] // CHECK: %[[dummy:.*]] = "test.dummy_op"(%[[m1_tensor]]) %0 = "test.dummy_op"(%t1) : (tensor) -> tensor // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] @@ -114,11 +114,11 @@ func @unused_unknown_op(%t1 : tensor) -> vector<5xf32> { %idx = arith.constant 0 : index %cst = arith.constant 0.0 : f32 + // ToTensorOp is inserted to pass in the result of the above bufferized op. + // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]] // CHECK: vector.transfer_read %[[m1]] %1 = vector.transfer_read %t1[%idx], %cst : tensor, vector<5xf32> - // ToTensorOp is inserted to pass in the result of the above bufferized op. - // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]] // CHECK: "test.dummy_op"(%[[m1_tensor]]) "test.dummy_op"(%t1) : (tensor) -> () @@ -158,10 +158,10 @@ %c0 = arith.constant 0 : index // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]] // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]] // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]] %0 = tensor.insert %f into %t1[%c0] : tensor - // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]] // CHECK-TENSOR: return %[[casted_tensor]] return %0 : tensor } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -168,25 +168,25 @@ -> (tensor, tensor, tensor, tensor) { // Hoisted allocs. - // CHECK: %[[REALLOC_A0_2:.*]] = memref.alloc - // CHECK: %[[REALLOC_A0:.*]] = memref.alloc - // CHECK: %[[REALLOC_A1:.*]] = memref.alloc + // CHECK: %[[REALLOC1:.*]] = memref.alloc + // CHECK: %[[REALLOC2:.*]] = memref.alloc + // CHECK: %[[REALLOC3:.*]] = memref.alloc // Alloc and copy the whole result tensor. Copy the tensor.extract_slice. - // CHECK: linalg.copy(%[[A0]], %[[REALLOC_A0]] - // CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC_A0]] + // CHECK: linalg.copy(%[[A0]], %[[REALLOC3]] + // CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC3]] // CHECK: linalg.copy(%[[t0]], %[[SV_A0]]) %r0 = tensor.insert_slice %t0 into %A0[0][4][1] : tensor<4xf32> into tensor // Alloc and copy the whole result tensor. Copy the tensor.extract_slice. // CHECK: linalg.copy(%[[A0]] - // CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC_A0_2]] + // CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC2]] // CHECK: linalg.copy(%[[t1]], %[[SV_A0_2]]) %r1 = tensor.insert_slice %t1 into %A0[0][4][1] : tensor<4xf32> into tensor // Still alloc the large tensor because %A1 is read after. Copy the tensor.extract_slice. // CHECK: linalg.copy(%[[A1]] - // CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC_A1]] + // CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC1]] // CHECK: linalg.copy(%[[t0]], %[[SV_A1]]) %r2 = tensor.insert_slice %t0 into %A1[0][4][1] : tensor<4xf32> into tensor @@ -196,7 +196,7 @@ // CHECK: linalg.copy(%[[t1]], %[[SV_A1_2]]) %r3 = tensor.insert_slice %t1 into %A1[0][4][1] : tensor<4xf32> into tensor - // CHECK: return %[[REALLOC_A0]], %[[REALLOC_A0_2]], %[[REALLOC_A1]] : + // CHECK: return %[[REALLOC3]], %[[REALLOC2]], %[[REALLOC1]] : // CHECK-SAME: memref, memref, memref return %r0, %r1, %r2, %r3: tensor, tensor, tensor, tensor }