diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1114,16 +1114,19 @@
   if (isa(op))
     return success();
 
+  // Check if op has tensor results or operands.
+  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
+  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
+  if (!hasTensorResult && !hasTensorOperand)
+    return success();
+
   // Bufferize using `BufferizableOpInterface`.
   if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
     return bufferizableOp.bufferize(b, state);
 
   // Other op with tensors. No bufferization method specified.
-  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
-  if (any_of(op->getOperandTypes(), isaTensor) ||
-      any_of(op->getResultTypes(), isaTensor))
-    return op->emitError() << "unsupported op with tensors";
-  return success();
+  return op->emitError() << "unsupported op with tensors";
 }
 
 static LogicalResult bufferizeFuncOpInternals(
@@ -2482,10 +2485,9 @@
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(op);
 
-    if (transferReadOp.getShapedType().isa<MemRefType>())
-      return failure();
-
     // TransferReadOp always reads from the bufferized op.source().
+    assert(transferReadOp.getShapedType().isa<TensorType>() &&
+           "only tensor types expected");
     Value v = state.lookupBuffer(transferReadOp.source());
     transferReadOp.sourceMutable().assign(v);
     return success();
@@ -2530,12 +2532,11 @@
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(op);
 
-    if (writeOp.getShapedType().isa<MemRefType>())
-      return failure();
-
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
+    assert(writeOp.getShapedType().isa<TensorType>() &&
+           "only tensor types expected");
     Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
     if (!resultBuffer)
       return failure();
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -167,16 +167,3 @@
   }
   return %r: tensor<4xi32>
 }
-
-// -----
-
-func @main() -> i32 {
-  %c0 = arith.constant 0: index
-  // expected-error @+1 {{expected result-less scf.execute_region containing op}}
-  %r = scf.execute_region -> i32 {
-    %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-    %e = tensor.extract %A[%c0]: tensor<4xi32>
-    scf.yield %e: i32
-  }
-  return %r: i32
-}