diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -55,7 +55,7 @@ std::unique_ptr builder; }; -inline void defaultRegionBuilder(ArrayRef args) {} +inline void defaultRegionBuilder(ValueRange args) {} /// Build a `linalg.generic` op with the specified `inputs`, `outputs` and /// `region`. @@ -76,8 +76,7 @@ Operation *makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, - function_ref)> regionBuilder = - defaultRegionBuilder, + function_ref regionBuilder = defaultRegionBuilder, ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); namespace ops { @@ -89,11 +88,11 @@ /// Build the body of a region to compute a scalar multiply, under the current /// ScopedContext, at the current insert point. -void mulRegionBuilder(ArrayRef args); +void mulRegionBuilder(ValueRange args); /// Build the body of a region to compute a scalar multiply-accumulate, under /// the current ScopedContext, at the current insert point. -void macRegionBuilder(ArrayRef args); +void macRegionBuilder(ValueRange args); /// TODO(ntv): In the future we should tie these implementations to something in /// Tablegen that generates the proper interfaces and the proper sugared named @@ -149,7 +148,7 @@ // TODO(ntv): Implement more useful pointwise operations on a per-need basis. -using MatmulRegionBuilder = function_ref args)>; +using MatmulRegionBuilder = function_ref; /// Build a linalg.generic, under the current ScopedContext, at the current /// insert point, that computes: diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -52,6 +52,7 @@ /// Prerequisites: /// All Handles have already captured previously constructed IR objects. BranchOp std_br(BlockHandle bh, ValueRange operands); +BranchOp std_br(Block *block, ValueRange operands); /// Creates a new mlir::Block* and branches to it from the current block. /// Argument types are specified by `operands`. @@ -78,6 +79,8 @@ CondBranchOp std_cond_br(Value cond, BlockHandle trueBranch, ValueRange trueOperands, BlockHandle falseBranch, ValueRange falseOperands); +CondBranchOp std_cond_br(Value cond, Block *trueBranch, ValueRange trueOperands, + Block *falseBranch, ValueRange falseOperands); /// Eagerly creates new mlir::Block* with argument types specified by /// `trueOperands`/`falseOperands`. diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -187,6 +187,8 @@ // trying to enter a Block that has already been constructed. class Append {}; +/// Deprecated. Use buildInNewBlock or appendToBlock instead. +/// /// A BlockBuilder is a NestedBuilder for mlir::Block*. /// This exists by opposition to LoopBuilder which is not related to an /// mlir::Block* but to a mlir::Value. @@ -231,6 +233,8 @@ BlockBuilder &operator=(BlockBuilder &other) = delete; }; +/// Deprecated. Use Block * instead. +/// /// A BlockHandle represents a (potentially "delayed") Block abstraction. /// This extra abstraction is necessary because an mlir::Block is not an /// mlir::Value. @@ -269,6 +273,35 @@ mlir::Block *block; }; +/// Creates a block in the region that contains the insertion block of the +/// OpBuilder currently at the top of ScopedContext stack (appends the block to +/// the region). Be aware that this will NOT update the insertion point of the +/// builder to insert into the newly constructed block. +Block *createBlock(TypeRange argTypes = llvm::None); + +/// Creates a block in the specified region using OpBuilder at the top of +/// ScopedContext stack (appends the block to the region). Be aware that this +/// will NOT update the insertion point of the builder to insert into the newly +/// constructed block. +Block *createBlockInRegion(Region ®ion, TypeRange argTypes = llvm::None); + +/// Calls "builderFn" with ScopedContext reconfigured to insert into "block" and +/// passes in the block arguments. If the block has a terminator, the operations +/// are inserted before the terminator, otherwise appended to the block. +void appendToBlock(Block *block, function_ref builderFn); + +/// Creates a block in the region that contains the insertion block of the +/// OpBuilder currently at the top of ScopedContext stack, and calls "builderFn" +/// to populate the body of the block while passing it the block arguments. +Block *buildInNewBlock(TypeRange argTypes, + function_ref builderFn); + +/// Creates a block in the specified region using OpBuilder at the top of +/// ScopedContext stack, and calls "builderFn" to populate the body of the block +/// while passing it the block arguments. +Block *buildInNewBlock(Region ®ion, TypeRange argTypes, + function_ref builderFn); + /// A StructuredIndexed represents an indexable quantity that is either: /// 1. a captured value, which is suitable for buffer and tensor operands, or; /// 2. a captured type, which is suitable for tensor return values. diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -73,8 +73,8 @@ Operation *mlir::edsc::makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, - function_ref)> regionBuilder, - ArrayRef otherValues, ArrayRef otherAttributes) { + function_ref regionBuilder, ArrayRef otherValues, + ArrayRef otherAttributes) { for (unsigned i = 0, e = outputs.size(); i + 1 < e; ++i) assert(!(outputs[i].getType().isa() && outputs[i + 1].getType().isa()) && @@ -136,15 +136,12 @@ assert(op->getRegion(0).empty()); OpBuilder opBuilder(op); ScopedContext scope(opBuilder, op->getLoc()); - BlockHandle b; - SmallVector handles(blockTypes.size()); - BlockBuilder(&b, op->getRegion(0), blockTypes, - handles)([&] { regionBuilder(b.getBlock()->getArguments()); }); + buildInNewBlock(op->getRegion(0), blockTypes, regionBuilder); assert(llvm::hasSingleElement(op->getRegion(0))); return op; } -void mlir::edsc::ops::mulRegionBuilder(ArrayRef args) { +void mlir::edsc::ops::mulRegionBuilder(ValueRange args) { using edsc::op::operator+; using edsc::op::operator*; assert(args.size() == 2 && "expected 2 block arguments"); @@ -152,7 +149,7 @@ linalg_yield(a * b); } -void mlir::edsc::ops::macRegionBuilder(ArrayRef args) { +void mlir::edsc::ops::macRegionBuilder(ValueRange args) { using edsc::op::operator+; using edsc::op::operator*; assert(args.size() == 3 && "expected 3 block arguments"); @@ -165,14 +162,14 @@ SmallVector iterTypes(O.getExprs().size(), IteratorType::Parallel); if (O.getType().isa()) { - auto fun = [&unaryOp](ArrayRef args) { + auto fun = [&unaryOp](ValueRange args) { assert(args.size() == 1 && "expected 1 block arguments"); Value a(args[0]); linalg_yield(unaryOp(a)); }; return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); } - auto fun = [&unaryOp](ArrayRef args) { + auto fun = [&unaryOp](ValueRange args) { assert(args.size() == 2 && "expected 2 block arguments"); Value a(args[0]); linalg_yield(unaryOp(a)); @@ -193,14 +190,14 @@ SmallVector iterTypes(O.getExprs().size(), IteratorType::Parallel); if (O.getType().isa()) { - auto fun = [&binaryOp](ArrayRef args) { + auto fun = [&binaryOp](ValueRange args) { assert(args.size() == 2 && "expected 2 block arguments"); Value a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); }; return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun); } - auto fun = [&binaryOp](ArrayRef args) { + auto fun = [&binaryOp](ValueRange args) { assert(args.size() == 3 && "expected 3 block arguments"); Value a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); diff --git a/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp b/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp --- a/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp +++ b/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp @@ -18,6 +18,10 @@ return OperationBuilder(bh.getBlock(), ops); } +BranchOp mlir::edsc::intrinsics::std_br(Block *block, ValueRange operands) { + return OperationBuilder(block, operands); +} + BranchOp mlir::edsc::intrinsics::std_br(BlockHandle *bh, ArrayRef types, MutableArrayRef captures, ValueRange operands) { @@ -27,6 +31,14 @@ return OperationBuilder(bh->getBlock(), ops); } +CondBranchOp mlir::edsc::intrinsics::std_cond_br(Value cond, Block *trueBranch, + ValueRange trueOperands, + Block *falseBranch, + ValueRange falseOperands) { + return OperationBuilder(cond, trueBranch, trueOperands, + falseBranch, falseOperands); +} + CondBranchOp mlir::edsc::intrinsics::std_cond_br(Value cond, BlockHandle trueBranch, ValueRange trueOperands, diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -87,6 +87,59 @@ return res; } +Block *mlir::edsc::createBlock(TypeRange argTypes) { + assert(ScopedContext::getContext() != nullptr && "ScopedContext not set up"); + OpBuilder &builder = ScopedContext::getBuilderRef(); + Block *block = builder.getInsertionBlock(); + assert(block != nullptr && + "insertion point not set up in the builder within ScopedContext"); + + return createBlockInRegion(*block->getParent(), argTypes); +} + +Block *mlir::edsc::createBlockInRegion(Region ®ion, TypeRange argTypes) { + assert(ScopedContext::getContext() != nullptr && "ScopedContext not set up"); + OpBuilder &builder = ScopedContext::getBuilderRef(); + + OpBuilder::InsertionGuard guard(builder); + return builder.createBlock(®ion, {}, argTypes); +} + +void mlir::edsc::appendToBlock(Block *block, + function_ref builderFn) { + assert(ScopedContext::getContext() != nullptr && "ScopedContext not set up"); + OpBuilder &builder = ScopedContext::getBuilderRef(); + + OpBuilder::InsertionGuard guard(builder); + if (block->empty() || block->back().isKnownNonTerminator()) + builder.setInsertionPointToEnd(block); + else + builder.setInsertionPoint(&block->back()); + builderFn(block->getArguments()); +} + +Block *mlir::edsc::buildInNewBlock(TypeRange argTypes, + function_ref builderFn) { + assert(ScopedContext::getContext() != nullptr && "ScopedContext not set up"); + OpBuilder &builder = ScopedContext::getBuilderRef(); + Block *block = builder.getInsertionBlock(); + assert(block != nullptr && + "insertion point not set up in the builder within ScopedContext"); + return buildInNewBlock(*block->getParent(), argTypes, builderFn); +} + +Block *mlir::edsc::buildInNewBlock(Region ®ion, TypeRange argTypes, + function_ref builderFn) { + assert(ScopedContext::getContext() != nullptr && "ScopedContext not set up"); + OpBuilder &builder = ScopedContext::getBuilderRef(); + + Block *block = createBlockInRegion(region, argTypes); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(block); + builderFn(block->getArguments()); + return block; +} + void mlir::edsc::LoopBuilder::operator()(function_ref fun) { // Call to `exit` must be explicit and asymmetric (cannot happen in the // destructor) because of ordering wrt comma operator. diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -182,12 +182,12 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - BlockHandle b1, functionBlock(&f.front()); - BlockBuilder(&b1, {}, {})([&] { std_constant_index(0); }); - BlockBuilder(b1, Append())([&] { std_constant_index(1); }); - BlockBuilder(b1, Append())([&] { std_ret(); }); - // Get back to entry block and add a branch into b1 - BlockBuilder(functionBlock, Append())([&] { std_br(b1, {}); }); + Block *b = + buildInNewBlock(TypeRange(), [&](ValueRange) { std_constant_index(0); }); + appendToBlock(b, [&](ValueRange) { std_constant_index(1); }); + appendToBlock(b, [&](ValueRange) { std_ret(); }); + // Get back to entry block and add a branch into "b". + appendToBlock(&f.front(), [&](ValueRange) { std_br(b, {}); }); // clang-format off // CHECK-LABEL: @builder_blocks @@ -211,28 +211,18 @@ Value c1(std_constant_int(42, 32)), c2(std_constant_int(1234, 32)); ReturnOp ret = std_ret(); - Value r; - Value args12[2]; - Value &arg1 = args12[0], &arg2 = args12[1]; - Value args34[2]; - Value &arg3 = args34[0], &arg4 = args34[1]; - BlockHandle b1, b2, functionBlock(&f.front()); - BlockBuilder(&b1, {c1.getType(), c1.getType()}, args12)( - // b2 has not yet been constructed, need to come back later. - // This is a byproduct of non-structured control-flow. - ); - BlockBuilder(&b2, {c1.getType(), c1.getType()}, args34)([&] { - std_br(b1, {arg3, arg4}); - }); + Block *b1 = createBlock({c1.getType(), c1.getType()}); + Block *b2 = buildInNewBlock({c1.getType(), c1.getType()}, + [&](ValueRange args) { std_br(b1, args); }); // The insertion point within the toplevel function is now past b2, we will // need to get back the entry block. - // This is what happens with unstructured control-flow.. - BlockBuilder(b1, Append())([&] { - r = arg1 + arg2; - std_br(b2, {arg1, r}); + // This is what happens with unstructured control-flow. + appendToBlock(b1, [&](ValueRange args) { + Value r = args[0] + args[1]; + std_br(b2, {args[0], r}); }); - // Get back to entry block and add a branch into b1 - BlockBuilder(functionBlock, Append())([&] { std_br(b1, {c1, c2}); }); + // Get back to entry block and add a branch into b1. + appendToBlock(&f.front(), [&](ValueRange) { std_br(b1, {c1, c2}); }); ret.erase(); // clang-format off @@ -251,68 +241,22 @@ f.erase(); } -TEST_FUNC(builder_blocks_eager) { - using namespace edsc::op; - auto f = makeFunction("builder_blocks_eager"); - - OpBuilder builder(f.getBody()); - ScopedContext scope(builder, f.getLoc()); - Value c1(std_constant_int(42, 32)), c2(std_constant_int(1234, 32)); - Value res; - Value args1And2[2], args3And4[2]; - Value &arg1 = args1And2[0], &arg2 = args1And2[1], &arg3 = args3And4[0], - &arg4 = args3And4[1]; - - // clang-format off - BlockHandle b1, b2; - { // Toplevel function scope. - // Build a new block for b1 eagerly. - std_br(&b1, {c1.getType(), c1.getType()}, args1And2, {c1, c2}); - // Construct a new block b2 explicitly with a branch into b1. - BlockBuilder(&b2, {c1.getType(), c1.getType()}, args3And4)([&]{ - std_br(b1, {arg3, arg4}); - }); - /// And come back to append into b1 once b2 exists. - BlockBuilder(b1, Append())([&]{ - res = arg1 + arg2; - std_br(b2, {arg1, res}); - }); - } - - // CHECK-LABEL: @builder_blocks_eager - // CHECK: %{{.*}} = constant 42 : i32 - // CHECK-NEXT: %{{.*}} = constant 1234 : i32 - // CHECK-NEXT: br ^bb1(%{{.*}}, %{{.*}} : i32, i32) - // CHECK-NEXT: ^bb1(%{{.*}}: i32, %{{.*}}: i32): // 2 preds: ^bb0, ^bb2 - // CHECK-NEXT: %{{.*}} = addi %{{.*}}, %{{.*}} : i32 - // CHECK-NEXT: br ^bb2(%{{.*}}, %{{.*}} : i32, i32) - // CHECK-NEXT: ^bb2(%{{.*}}: i32, %{{.*}}: i32): // pred: ^bb1 - // CHECK-NEXT: br ^bb1(%{{.*}}, %{{.*}} : i32, i32) - // CHECK-NEXT: } - // clang-format on - f.print(llvm::outs()); - f.erase(); -} - TEST_FUNC(builder_cond_branch) { auto f = makeFunction("builder_cond_branch", {}, {IntegerType::get(1, &globalContext())}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - Value funcArg(f.getArgument(0)); Value c32(std_constant_int(32, 32)), c64(std_constant_int(64, 64)), c42(std_constant_int(42, 32)); ReturnOp ret = std_ret(); - Value arg1; - Value args23[2]; - BlockHandle b1, b2, functionBlock(&f.front()); - BlockBuilder(&b1, c32.getType(), arg1)([&] { std_ret(); }); - BlockBuilder(&b2, {c64.getType(), c32.getType()}, args23)([&] { std_ret(); }); - // Get back to entry block and add a conditional branch - BlockBuilder(functionBlock, Append())([&] { - std_cond_br(funcArg, b1, {c32}, b2, {c64, c42}); + Block *b1 = buildInNewBlock(c32.getType(), [&](ValueRange) { std_ret(); }); + Block *b2 = buildInNewBlock({c64.getType(), c32.getType()}, + [&](ValueRange) { std_ret(); }); + // Get back to entry block and add a conditional branch. + appendToBlock(&f.front(), [&](ValueRange args) { + std_cond_br(args[0], b1, {c32}, b2, {c64, c42}); }); ret.erase(); @@ -331,44 +275,6 @@ f.erase(); } -TEST_FUNC(builder_cond_branch_eager) { - using namespace edsc::op; - auto f = makeFunction("builder_cond_branch_eager", {}, - {IntegerType::get(1, &globalContext())}); - - OpBuilder builder(f.getBody()); - ScopedContext scope(builder, f.getLoc()); - Value arg0(f.getArgument(0)); - Value c32(std_constant_int(32, 32)), c64(std_constant_int(64, 64)), - c42(std_constant_int(42, 32)); - - // clang-format off - BlockHandle b1, b2; - Value arg1[1], args2And3[2]; - std_cond_br(arg0, - &b1, c32.getType(), arg1, c32, - &b2, {c64.getType(), c32.getType()}, args2And3, {c64, c42}); - BlockBuilder(b1, Append())([]{ - std_ret(); - }); - BlockBuilder(b2, Append())([]{ - std_ret(); - }); - - // CHECK-LABEL: @builder_cond_branch_eager - // CHECK: %{{.*}} = constant 32 : i32 - // CHECK-NEXT: %{{.*}} = constant 64 : i64 - // CHECK-NEXT: %{{.*}} = constant 42 : i32 - // CHECK-NEXT: cond_br %{{.*}}, ^bb1(%{{.*}} : i32), ^bb2(%{{.*}}, %{{.*}} : i64, i32) - // CHECK-NEXT: ^bb1(%{{.*}}: i32): // pred: ^bb0 - // CHECK-NEXT: return - // CHECK-NEXT: ^bb2(%{{.*}}: i64, %{{.*}}: i32): // pred: ^bb0 - // CHECK-NEXT: return - // clang-format on - f.print(llvm::outs()); - f.erase(); -} - TEST_FUNC(builder_helpers) { using namespace edsc::op; auto f32Type = FloatType::getF32(&globalContext()); @@ -433,13 +339,10 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - BlockHandle b1; - // clang-format off std_constant_int(0, 32); - (BlockBuilder(&b1))([]{ - std_constant_int(1, 32); - }); + buildInNewBlock({}, [&](ValueRange) { std_constant_int(1, 32); }); std_constant_int(2, 32); + // clang-format off // CHECK-LABEL: @insertion_in_block // CHECK: {{.*}} = constant 0 : i32 // CHECK: {{.*}} = constant 2 : i32 @@ -1057,7 +960,7 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); Value A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); - auto contractionBuilder = [](ArrayRef args) { + auto contractionBuilder = [](ValueRange args) { assert(args.size() == 3 && "expected 3 block arguments"); (linalg_yield(vector_contraction_matmul(args[0], args[1], args[2]))); };