diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -439,20 +439,6 @@ const BufferizationOptions &options; }; -/// Bufferize all ops in the given region. -LogicalResult bufferize(RewriterBase &rewriter, Region *region, - BufferizationState &state); - -/// Bufferize all ops in the given block. -LogicalResult bufferize(RewriterBase &rewriter, Block *block, - BufferizationState &state); - -/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this -/// function returns immediately. Otherwise, it calls the `bufferize` interface -/// method of `BufferizableOpInterface`. -LogicalResult bufferize(RewriterBase &rewriter, Operation *op, - BufferizationState &state); - /// Return a contiguous MemRefType (i.e. with canonical/empty layout map) /// with the same shape as `shapedType` and specified `layout` and /// `addressSpace`. 
@@ -524,17 +510,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { - auto isaTensor = [](Type t) { return t.isa(); }; - if (any_of(op->getOperandTypes(), isaTensor) || - any_of(op->getResultTypes(), isaTensor)) - if (!state.getOptions().allowUnknownOps) - return op->emitError() << "unsupported op with tensors"; - - for (Region &region : op->getRegions()) - if (failed(comprehensive_bufferize::bufferize(rewriter, &region, state))) - return failure(); - - return success(); + return failure(); } bool isAllocationHoistingBarrier(Operation *op) const { return true; } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -453,59 +453,6 @@ rewriter.eraseOp(op); } -LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( - RewriterBase &rewriter, Region *region, BufferizationState &state) { - for (Block &block : *region) - if (failed(bufferize(rewriter, &block, state))) - return failure(); - return success(); -} - -LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( - RewriterBase &rewriter, Block *block, BufferizationState &state) { - // Ops may get deleted during the traversal, so do not iterate over `block` - // directly. - SmallVector ops; - ops.reserve(block->getOperations().size()); - for (Operation &op : *block) - ops.push_back(&op); - for (Operation *op : ops) - if (failed(bufferize(rewriter, op, state))) - return failure(); - return success(); -} - -LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( - RewriterBase &rewriter, Operation *op, BufferizationState &state) { - // Check if op has tensor results or operands. 
- auto isaTensor = [](Type t) { return t.isa(); }; - bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); - bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); - bool hasRegions = !op->getRegions().empty(); - - // No tensor results/operands or regions. We are done. - if (!hasTensorResult && !hasTensorOperand && !hasRegions) - return success(); - - // Bufferize using `BufferizableOpInterface`. Interface implementations are - // responsible for bufferizing nested ops. - if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { - rewriter.setInsertionPoint(op); - return bufferizableOp.bufferize(rewriter, state); - } - - // `op` is an unbufferizable tensor op. - if (!state.getOptions().allowUnknownOps) - return op->emitError() << "unsupported op with tensors"; - - // Bufferize all regions. - for (Region &region : op->getRegions()) - if (failed(bufferize(rewriter, &region, state))) - return failure(); - - return success(); -} - //===----------------------------------------------------------------------===// // Bufferization-specific scoped alloc/dealloc insertion support. //===----------------------------------------------------------------------===// @@ -656,28 +603,15 @@ if (auto toTensorOp = tensor.getDefiningOp()) return toTensorOp.memref(); - if (!isFunctionArgument(tensor)) { - if (static_cast(options.dynCastBufferizableOp(tensor))) { - // Dump tensor for easier debugging. - tensor.dump(); - llvm_unreachable("op is known, but has not been bufferized yet"); - return Value(); - } - if (!options.allowUnknownOps) { - // Dump tensor for easier debugging. - tensor.dump(); - // Note: An assertion should already have failed earlier. - llvm_unreachable("unknown ops are not allowed"); - return Value(); - } - } - // Insert to_memref op. 
OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, tensor); - return rewriter.create( - tensor.getLoc(), - getDynamicMemRefType(tensor.getType().cast()), tensor); + Type memrefType = + tensor.getType().isa() + ? getDynamicMemRefType(tensor.getType().cast()) + : getContiguousOrUnrankedMemRefType(tensor.getType()); + return rewriter.create(tensor.getLoc(), memrefType, + tensor); } bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace( diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp @@ -54,21 +54,18 @@ BufferizationState &state) const { auto toMemrefOp = cast(op); - // Fold to_memref(to_tensor(x)) to x. + // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. if (auto toTensorOp = toMemrefOp.tensor().getDefiningOp()) { - rewriter.replaceOp(toMemrefOp, toTensorOp.memref()); + Value buffer = toTensorOp.memref(); + if (toTensorOp.memref().getType() != toMemrefOp.getType()) + buffer = rewriter.create(toMemrefOp.getLoc(), buffer, + toMemrefOp.getType()); + rewriter.replaceOp(toMemrefOp, buffer); return success(); } - // If a ToMemrefOp's tensor operand has not been bufferized yet, the op - // remains unchanged. All IR up to this ToMemrefOp has already been - // bufferized, unless there were unknown ops that could be bufferized. 
- assert((isFunctionArgument(toMemrefOp.tensor()) || - state.getOptions().allowUnknownOps) && - "expected that tensor is mapped"); - - return success(); + return failure(); } }; @@ -87,7 +84,7 @@ bufferization::ToTensorOp> { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { - return success(); + return failure(); } bool isWritable(Operation *op, Value value, BufferizationState &state) const { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -115,6 +115,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" @@ -550,6 +551,13 @@ return success(); } +/// Return true if the given op has a tensor result or a tensor operand. +static bool hasTensorSemantics(Operation *op) { + bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); + bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); + return hasTensorResult || hasTensorOperand; +} + /// Analyze all ops that are contained in `op`. static LogicalResult inPlaceAnalysis(Operation *op, BufferizationAliasInfo &aliasInfo, @@ -560,8 +568,7 @@ SmallVector ops; op->walk([&](Operation *op) { // No tensors => no buffers. - if (none_of(op->getOperandTypes(), isaTensor) && - none_of(op->getResultTypes(), isaTensor)) + if (!hasTensorSemantics(op)) return; ops.push_back(op); }); @@ -658,6 +665,65 @@ return runComprehensiveBufferize(op, *options, state); } +/// Rewrite pattern that bufferizes bufferizable ops. +// TODO: Match only BufferizableOpInterface. This does not work with external +// models at the moment. 
+struct BufferizationPattern : public RewritePattern { + BufferizationPattern(MLIRContext *context, BufferizationState &state, + PatternBenefit benefit = 1) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context), state(state) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // No tensors => no buffers. + if (!hasTensorSemantics(op)) + return failure(); + auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op); + if (!bufferizableOp) + return failure(); + return bufferizableOp.bufferize(rewriter, state); + } + +private: + BufferizationState &state; +}; + +/// Check the result of bufferization. Return an error if an op was not +/// bufferized, unless partial bufferization is allowed. +static LogicalResult +checkBufferizationResult(Operation *op, const BufferizationOptions &options) { + if (!options.allowUnknownOps) { + // Check if all ops were bufferized. + LogicalResult status = success(); + op->walk([&](Operation *op) { + if (!hasTensorSemantics(op)) + return WalkResult::advance(); + + // Bufferization dialect ops will canonicalize away if all other ops are + // bufferized. + if (isa(op)) + return WalkResult::advance(); + + // Ops that are not in the allow list can be ignored. + if (!options.isOpAllowed(op)) + return WalkResult::advance(); + + // Ops without any uses and no side effects will fold away. + if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) + return WalkResult::advance(); + + status = op->emitError("op was not bufferized"); + return WalkResult::interrupt(); + }); + + if (failed(status)) + return status; + } + + return success(); +} + LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( Operation *op, const BufferizationOptions &options, BufferizationState &state) { @@ -693,8 +759,10 @@ } // Bufferize the op and its nested ops. 
- if (failed(bufferize(rewriter, op, state))) + OwningRewritePatternList patterns(op->getContext()); + patterns.add(op->getContext(), state); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) return failure(); - return success(); + return checkBufferizationResult(op, options); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -64,14 +64,12 @@ // Set insertion point now that potential alloc/dealloc are introduced. rewriter.setInsertionPoint(op); - auto bufferizedOp = cast(op.clone( - rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); + op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands); // Replace the results of the old op with the new output buffers. state.replaceOp(rewriter, op, newOutputBuffers); - return comprehensive_bufferize::bufferize(rewriter, bufferizedOp.getBlock(), - state); + return success(); } /// Linalg OpResults usually bufferize inplace with their tied (output @@ -309,7 +307,7 @@ for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) { Value oldArg = std::get<0>(it); Value newArg = std::get<1>(it); - rewriter.setInsertionPointToStart(newTiledLoopOp->getBlock()); + rewriter.setInsertionPointToStart(newTiledLoopOp.getBody()); if (oldArg.getType().isa()) { newBlockArgs.push_back(rewriter.create( oldArg.getLoc(), newArg)); @@ -323,17 +321,21 @@ newBlockArgs); // Replace previous terminator with a new one that does not yield anything. 
- Operation *oldTerminator = newTiledLoopOp.getBody()->getTerminator(); + auto oldTerminator = + cast(newTiledLoopOp.getBody()->getTerminator()); rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody()); + for (Value val : oldTerminator.values()) + if (val.getType().isa()) + // Make sure that yielded values do not DCE away. + rewriter.create( + oldTerminator->getLoc(), val); rewriter.create(oldTerminator->getLoc()); rewriter.eraseOp(oldTerminator); // Replace results and delete old op. state.replaceOp(rewriter, op, newResults); - // Bufferize loop body. - return comprehensive_bufferize::bufferize(rewriter, - newTiledLoopOp.getBody(), state); + return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -665,20 +665,12 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { +#ifndef NDEBUG auto returnOp = cast(op); assert(isa(returnOp->getParentOp()) && "only support FuncOp parent for ReturnOp"); - - for (OpOperand &operand : returnOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast(); - if (!tensorType) - continue; - Value v = state.lookupBuffer(rewriter, operand.get()); - Value returnTensor = - rewriter.create(returnOp.getLoc(), v); - operand.set(returnTensor); - } - return success(); +#endif // NDEBUG + return failure(); } }; @@ -686,10 +678,7 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { - auto funcOp = cast(op); - - // Bufferize function body. - return comprehensive_bufferize::bufferize(rewriter, &funcOp.body(), state); + return failure(); } /// Return `true` if the given function argument is writable. 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -64,14 +64,13 @@ BufferizationState &state) const { // TODO: Add bufferization support when needed. scf.execute_region should be // bufferized similar to scf.if. - auto executeRegionOp = cast(op); + // auto executeRegionOp = cast(op); bool hasTensorReturnType = any_of( op->getResultTypes(), [](Type t) { return t.isa(); }); if (hasTensorReturnType) return op->emitError( "scf.execute_region with tensor result not supported"); - return comprehensive_bufferize::bufferize( - rewriter, &executeRegionOp.getRegion(), state); + return success(); } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -196,14 +195,6 @@ // Replace op results. state.replaceOp(rewriter, op, newIfOp->getResults()); - // Bufferize then/else blocks. - if (failed(comprehensive_bufferize::bufferize(rewriter, newIfOp.thenBlock(), - state))) - return failure(); - if (failed(comprehensive_bufferize::bufferize(rewriter, newIfOp.elseBlock(), - state))) - return failure(); - return success(); } @@ -337,10 +328,6 @@ // Replace loop results. state.replaceOp(rewriter, op, newForOp->getResults()); - // Bufferize loop body. 
- if (failed(comprehensive_bufferize::bufferize(rewriter, loopBody, state))) - return failure(); - return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -444,6 +445,10 @@ // Copy tensor. Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source()); state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView); + } else { + // Make sure that `source` does not DCE away. 
+ rewriter.create( + loc, insertSliceOp.source()); } state.replaceOp(rewriter, op, dstMemref); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir @@ -31,23 +31,23 @@ %v1 = arith.constant 1.0 : f32 %v2 = arith.constant 2.0 : f32 - // CHECK-NEXT: %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref - // CHECK-NEXT: %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32> // CHECK-NEXT: %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32> + // CHECK-NEXT: %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32> + // CHECK-NEXT: %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref %A = linalg.init_tensor [64] : tensor<64xf32> %B = linalg.init_tensor [64] : tensor<64xf32> %C = linalg.init_tensor [] : tensor // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32> + // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]> // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32> + // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]> // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref + // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref to memref %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32> %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32> %CC = linalg.fill(%v0, %C) : f32, tensor -> tensor - // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]> - // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]> - // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref to memref // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]]) 
%res = call @init_and_dot(%AA, %BB, %CC) : (tensor<64xf32>, tensor<64xf32>, tensor) -> tensor diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -142,7 +142,7 @@ func @unknown_op(%A : tensor<4xf32>) -> tensor<4xf32> { - // expected-error @+1 {{unsupported op with tensors}} + // expected-error @+1 {{op was not bufferized}} %r = "marklar"(%A) : (tensor<4xf32>) -> (tensor<4xf32>) return %r: tensor<4xf32> } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -168,9 +168,9 @@ -> (tensor, tensor, tensor, tensor) { // Hoisted allocs. - // CHECK: %[[REALLOC1:.*]] = memref.alloc - // CHECK: %[[REALLOC2:.*]] = memref.alloc // CHECK: %[[REALLOC3:.*]] = memref.alloc + // CHECK: %[[REALLOC2:.*]] = memref.alloc + // CHECK: %[[REALLOC1:.*]] = memref.alloc // Alloc and copy the whole result tensor. Copy the tensor.extract_slice. 
// CHECK: linalg.copy(%[[A0]], %[[REALLOC3]] @@ -516,23 +516,23 @@ %v1 = arith.constant 1.0 : f32 %v2 = arith.constant 2.0 : f32 - // CHECK-NEXT: %[[C:.*]] = memref.alloc() {alignment = 128 : i64} : memref - // CHECK-NEXT: %[[B:.*]] = memref.alloc() {alignment = 128 : i64} : memref<64xf32> // CHECK-NEXT: %[[A:.*]] = memref.alloc() {alignment = 128 : i64} : memref<64xf32> + // CHECK-NEXT: %[[B:.*]] = memref.alloc() {alignment = 128 : i64} : memref<64xf32> + // CHECK-NEXT: %[[C:.*]] = memref.alloc() {alignment = 128 : i64} : memref %A = linalg.init_tensor [64] : tensor<64xf32> %B = linalg.init_tensor [64] : tensor<64xf32> %C = linalg.init_tensor [] : tensor // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32> + // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]> // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32> + // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]> // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref + // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref to memref %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32> %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32> %CC = linalg.fill(%v0, %C) : f32, tensor -> tensor - // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]> - // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]> - // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref to memref // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]]) %res = call @init_and_dot(%AA, %BB, %CC) : (tensor<64xf32>, tensor<64xf32>, tensor) -> tensor @@ -727,6 +727,7 @@ tensor<256x192xf32> to tensor<256x16xf32> // %4 does not match an insert_slice, it cannot be bufferized inplace and needs to alloc. 
+ // CHECK: %[[T:.*]] = memref.subview %[[C]][%[[I]], %[[J]]] [8, 16] [1, 1] %4 = tensor.extract_slice %C[%arg3, %arg5] [8, 16] [1, 1] : tensor<128x192xf32> to tensor<8x16xf32> @@ -752,7 +753,6 @@ // insert_slice is inplace but its source comes from an equivalent buffer // that is not in place. So we must insert a copy of the small buffer into // the bigger buffer. - // CHECK: %[[T:.*]] = memref.subview %[[C]][%[[I]], %[[J]]] [8, 16] [1, 1] // CHECK: linalg.copy(%[[ALLOC]], %[[T]]) %7 = tensor.insert_slice %6 into %arg6[%arg3, %arg5] [8, 16] [1, 1] : tensor<8x16xf32> into tensor<128x192xf32> @@ -823,7 +823,8 @@ // init_tensor itself does not alloc but forwards to the **second** // insert_slice. InitTensorOp replaces the init_tensor with an out-of-place // extract_slice. - // CHECK: %[[EXTRACT_SLICE_ALLOC:.*]] = memref.alloc(%[[sz]]) + // CHECK: %[[EXTRACT_SLICE_ALLOC:.*]] = memref.alloc(%[[sz]]) + // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] %a = linalg.init_tensor[%sz] : tensor // CHECK: linalg.fill({{.*}}, %[[EXTRACT_SLICE_ALLOC]]) : f32, memref @@ -834,7 +835,6 @@ // CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]]) : memref, memref %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor into tensor - // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] // CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[T_SUBVIEW]]) %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor into tensor diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6436,6 +6436,7 @@ includes = ["include"], deps = [ ":BufferizableOpInterface", + ":BufferizationDialect", ":IR", ":MemRefDialect", ":Support", @@ -6719,6 +6720,7 @@ ":Pass", ":StandardOps", ":Support", + ":Transforms", "//llvm:Support", ], )