diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -251,6 +251,16 @@
   /// not yet bound to mlir::Value.
   BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
 
+  /// Constructs a new mlir::Block with argument types derived from `args` and
+  /// appends it as the last block of `region`.
+  /// Captures the new block in `bh` and its arguments into `args`.
+  /// Enters the new mlir::Block* and sets the insertion point to its end.
+  ///
+  /// Prerequisites:
+  ///   The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
+  ///   not yet bound to mlir::Value.
+  BlockBuilder(BlockHandle *bh, Region &region, ArrayRef<ValueHandle *> args);
+
   /// The only purpose of this operator is to serve as a sequence point so that
   /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
   /// scoped within a BlockBuilder.
@@ -450,6 +460,11 @@
   /// Delegates block creation to MLIR and wrap the resulting mlir::Block.
   static BlockHandle create(ArrayRef<Type> argTypes);
 
+  /// Creates a new mlir::Block with arguments of types `argTypes`, appends it
+  /// as the last block of `region` and wraps it in a BlockHandle.
+  /// The insertion point of the current ScopedContext builder is not modified.
+  static BlockHandle createInRegion(Region &region, ArrayRef<Type> argTypes);
+
   operator bool() { return block != nullptr; }
   operator mlir::Block *() { return block; }
   mlir::Block *getBlock() { return block; }
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -184,14 +184,16 @@
                              ? getElementTypeOrSelf(it.value())
                              : it.value().getType());
 
-  assert(op->getRegions().front().empty());
-  op->getRegions().front().push_front(new Block);
-  OpBuilder bb(op->getRegions().front());
-  ScopedContext scope(bb, op->getLoc());
+  assert(op->getNumRegions() == 1);
+  assert(op->getRegion(0).empty());
+  OpBuilder opBuilder(op);
+  ScopedContext scope(opBuilder, op->getLoc());
   BlockHandle b;
   auto handles = makeValueHandles(blockTypes);
-  BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
+  BlockBuilder(&b, op->getRegion(0),
+               makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
       [&] { regionBuilder(b.getBlock()->getArguments()); });
+  assert(op->getRegion(0).getBlocks().size() == 1);
   return op;
 }
 
diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp
--- a/mlir/lib/EDSC/Builders.cpp
+++ b/mlir/lib/EDSC/Builders.cpp
@@ -133,6 +133,20 @@
   return res;
 }
 
+BlockHandle mlir::edsc::BlockHandle::createInRegion(Region &region,
+                                                    ArrayRef<Type> argTypes) {
+  BlockHandle res;
+  region.push_back(new Block);
+  res.block = &region.back();
+  // The block is appended directly to `region` rather than created through an
+  // OpBuilder, so the insertion point of the enclosing ScopedContext builder is
+  // never touched and needs no save/restore.
+  for (auto t : argTypes) {
+    res.block->addArgument(t);
+  }
+  return res;
+}
+
 static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
                                            ArrayRef<ValueHandle> ubs,
                                            int64_t step) {
@@ -285,6 +299,24 @@
   enter(bh->getBlock());
 }
 
+mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, Region &region,
+                                       ArrayRef<ValueHandle *> args) {
+  assert(!*bh && "BlockHandle already captures a block, use "
+                 "the explicit BlockBuilder(bh, Append())({}) syntax instead.");
+  SmallVector<Type, 8> types;
+  for (auto *a : args) {
+    assert(!a->hasValue() &&
+           "Expected delayed ValueHandle that has not yet captured.");
+    types.push_back(a->getType());
+  }
+  *bh = BlockHandle::createInRegion(region, types);
+  // Bind each delayed handle to the corresponding new block argument.
+  for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
+    *(std::get<0>(it)) = ValueHandle(std::get<1>(it));
+  }
+  enter(bh->getBlock());
+}
+
 /// Only serves as an ordering point between entering nested block and creating
 /// stmts.
 void mlir::edsc::BlockBuilder::operator()(function_ref<void(void)> fun) {
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -876,7 +876,7 @@
 //       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
-///      CHECK:   ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
+///      CHECK:   ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
 //       CHECK:     %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
 //       CHECK:     %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
 //       CHECK:     linalg.yield %[[a4]] : f32
@@ -906,7 +906,7 @@
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>,
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
-///      CHECK:   ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
+///      CHECK:   ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
 //       CHECK:     %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
 //       CHECK:     %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
 //       CHECK:     linalg.yield %[[a4]] : f32
@@ -937,7 +937,7 @@
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>,
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-//       CHECK:   ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
+//       CHECK:   ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
 //       CHECK:     %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
 //       CHECK:     %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
 //       CHECK:     linalg.yield %[[a4]] : f32