diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -44,6 +44,7 @@
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
     size_t resultNum = std::distance(op->getOpResults().begin(),
                                      llvm::find(op->getOpResults(), opResult));
+    // TODO: Support multiple blocks.
     assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
            "expected exactly 1 block");
     auto yieldOp = dyn_cast<scf::YieldOp>(
@@ -66,13 +67,59 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
-    // TODO: Add bufferization support when needed. scf.execute_region should be
-    // bufferized similar to scf.if.
-    bool hasTensorReturnType = any_of(
-        op->getResultTypes(), [](Type t) { return t.isa<TensorType>(); });
-    if (hasTensorReturnType)
-      return op->emitError(
-          "scf.execute_region with tensor result not supported");
+    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
+
+    // Compute new result types.
+    SmallVector<Type> newResultTypes;
+    for (Type type : executeRegionOp->getResultTypes()) {
+      if (auto rankedTensorType = type.dyn_cast<RankedTensorType>()) {
+        newResultTypes.push_back(getDynamicMemRefType(rankedTensorType));
+      } else if (auto tensorType = type.dyn_cast<TensorType>()) {
+        newResultTypes.push_back(
+            getUnrankedMemRefType(tensorType.getElementType()));
+      } else {
+        newResultTypes.push_back(type);
+      }
+    }
+
+    // Create new op and move over region.
+    auto newOp =
+        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
+    newOp.getRegion().takeBody(executeRegionOp.getRegion());
+
+    // Update terminator.
+    assert(newOp.getRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    Block *newBlock = &newOp.getRegion().front();
+    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> newYieldValues;
+    for (auto it : llvm::enumerate(yieldOp.getResults())) {
+      Value val = it.value();
+      if (val.getType().isa<TensorType>()) {
+        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
+            yieldOp.getLoc(), newResultTypes[it.index()], val));
+      } else {
+        newYieldValues.push_back(val);
+      }
+    }
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
+
+    // Update all uses of the old op.
+    rewriter.setInsertionPointAfter(newOp);
+    SmallVector<Value> newResults;
+    for (auto it : llvm::enumerate(executeRegionOp->getResultTypes())) {
+      if (it.value().isa<TensorType>()) {
+        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
+            executeRegionOp.getLoc(), newOp->getResult(it.index())));
+      } else {
+        newResults.push_back(newOp->getResult(it.index()));
+      }
+    }
+
+    // Replace old op.
+    rewriter.replaceOp(executeRegionOp, newResults);
+    return success();
   }
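Note on the rewrite implemented by bufferize() above: the scf.execute_region op is kept, but its tensor results become memref results, with casts materialized on both sides of the region boundary; the ops inside the region are bufferized separately through their own interfaces. A minimal before/after sketch of just this op's rewrite (illustrative, not verbatim pass output; %t1, %f, and %idx are assumed to be defined elsewhere, and #map stands for the fully dynamic 1-D layout produced by getDynamicMemRefType):

  #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

  // Before:
  %r = scf.execute_region -> tensor<?xf32> {
    %t = tensor.insert %f into %t1[%idx] : tensor<?xf32>
    scf.yield %t : tensor<?xf32>
  }

  // After: tensor results become memref results; the casts bridge the gap
  // until the rest of the pass bufferizes the body and folds them away.
  %m = scf.execute_region -> memref<?xf32, #map> {
    %t = tensor.insert %f into %t1[%idx] : tensor<?xf32>
    %b = bufferization.to_memref %t : memref<?xf32, #map>
    scf.yield %b : memref<?xf32, #map>
  }
  %r = bufferization.to_tensor %m : memref<?xf32, #map>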
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -159,8 +159,8 @@
 
 // -----
 
+// expected-error @+1 {{memref return type is unsupported}}
 func @main() -> tensor<4xi32> {
-  // expected-error @+1 {{scf.execute_region with tensor result not supported}}
   %r = scf.execute_region -> tensor<4xi32> {
     %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
     scf.yield %A: tensor<4xi32>
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -446,6 +446,59 @@
 
 // -----
 
+// CHECK-LABEL: func @execute_region_test(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
+func @execute_region_test(%t1 : tensor<?xf32> {linalg.inplaceable = "true"})
+    -> (f32, tensor<?xf32>, f32)
+{
+  %f1 = arith.constant 0.0 : f32
+  %f2 = arith.constant 1.0 : f32
+  %idx = arith.constant 7 : index
+
+  // scf.execute_region is canonicalized away after bufferization. So just the
+  // memref.store is left over.
+
+  // CHECK: memref.store %{{.*}}, %[[m1]][%{{.*}}]
+  %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
+    %t2 = tensor.insert %f2 into %t1[%idx] : tensor<?xf32>
+    scf.yield %f1, %t2, %f2 : f32, tensor<?xf32>, f32
+  }
+
+  // CHECK: return %{{.*}}, %{{.*}} : f32, f32
+  return %0, %1, %2 : f32, tensor<?xf32>, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @execute_region_with_conflict(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
+func @execute_region_with_conflict(%t1 : tensor<?xf32> {linalg.inplaceable = "true"})
+    -> (f32, tensor<?xf32>, f32)
+{
+  %f1 = arith.constant 0.0 : f32
+  %idx = arith.constant 7 : index
+
+  // scf.execute_region is canonicalized away after bufferization. So just the
+  // memref.store is left over.
+
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK: memref.copy %[[m1]], %[[alloc]]
+  // CHECK: memref.store %{{.*}}, %[[alloc]][%{{.*}}]
+  %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
+    %t2 = tensor.insert %f1 into %t1[%idx] : tensor<?xf32>
+    scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
+  }
+
+  // CHECK: %[[load:.*]] = memref.load %[[m1]]
+  %3 = tensor.extract %t1[%idx] : tensor<?xf32>
+
+  // CHECK: return %{{.*}}, %[[casted]], %[[load]] : f32, memref<?xf32>, f32
+  return %0, %1, %3 : f32, tensor<?xf32>, f32
+}
+
+// -----
+
 // CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
 // CHECK: func private @some_external_func(memref<?xf32, #[[$DYN_1D_MAP]]>)
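For reference, the end state that the @execute_region_test CHECK lines describe, reassembled as a sketch: after bufferization and canonicalization the region is gone, the tensor.insert has become an in-place store into the inplaceable argument, and the returned tensor, being equivalent to that argument, is dropped from the signature (hence only two returned values in the CHECK). The exact signature and #map layout below are assumptions; the CHECKs only pin down the store and the scalar returns:

  #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

  func @execute_region_test(%m1: memref<?xf32, #map>) -> (f32, f32) {
    %f1 = arith.constant 0.0 : f32
    %f2 = arith.constant 1.0 : f32
    %idx = arith.constant 7 : index
    // The tensor.insert bufferized to an in-place store into %m1.
    memref.store %f2, %m1[%idx] : memref<?xf32, #map>
    return %f1, %f2 : f32, f32
  }

@execute_region_with_conflict differs only in that %t1 is read again (tensor.extract) after the insert, so the insert cannot happen in place; bufferization allocates a fresh buffer and copies %t1 into it before storing, which is exactly what the alloc/cast/copy CHECK lines verify.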