diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -400,20 +400,19 @@
   if (isa(op))
     return success();
 
-  // Check if op has tensor results or operands.
-  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
-  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
-  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
-  if (!hasTensorResult && !hasTensorOperand)
-    return success();
-
   // Bufferize using `BufferizableOpInterface`.
   b.setInsertionPoint(op);
   if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
     return bufferizableOp.bufferize(b, state);
 
-  // Other op with tensors. No bufferization method specified.
-  return op->emitError() << "unsupported op with tensors";
+  // Check if op has tensor results or operands.
+  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
+  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
+  if (hasTensorResult || hasTensorOperand)
+    return op->emitError() << "unsupported op with tensors";
+
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -2454,14 +2454,14 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto transferReadOp = cast<vector::TransferReadOp>(op);
+    if (!transferReadOp.getShapedType().isa<TensorType>())
+      return success();
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(op);
 
     // TransferReadOp always reads from the bufferized op.source().
-    assert(transferReadOp.getShapedType().isa<TensorType>() &&
-           "only tensor types expected");
     Value v = state.lookupBuffer(transferReadOp.source());
     transferReadOp.sourceMutable().assign(v);
     return success();
@@ -2501,6 +2501,8 @@
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
+    if (!writeOp.getShapedType().isa<TensorType>())
+      return success();
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
@@ -2509,8 +2511,6 @@
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
-    assert(writeOp.getShapedType().isa<TensorType>() &&
-           "only tensor types expected");
     Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
     if (!resultBuffer)
       return failure();