diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -298,13 +298,15 @@ /// Insert the given operation at the current insertion point and return it. virtual Operation *insert(Operation *op); - /// Add new block and set the insertion point to the end of it. The block is - /// inserted at the provided insertion point of 'parent'. - Block *createBlock(Region *parent, Region::iterator insertPt = {}); - - /// Add new block and set the insertion point to the end of it. The block is - /// placed before 'insertBefore'. - Block *createBlock(Block *insertBefore); + /// Add new block with 'argTypes' arguments and set the insertion point to the + /// end of it. The block is inserted at the provided insertion point of + /// 'parent'. + virtual Block *createBlock(Region *parent, Region::iterator insertPt = {}, + TypeRange argTypes = llvm::None); + + /// Add new block with 'argTypes' arguments and set the insertion point to the + /// end of it. The block is placed before 'insertBefore'. + Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None); /// Returns the current block of the builder. Block *getBlock() const { return block; } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -344,6 +344,14 @@ /// otherwise an assert will be issued. void eraseOp(Operation *op) override; + /// Keep the 'createBlock(Block *, TypeRange)' overload visible; declaring + /// the override below would otherwise hide it. + using PatternRewriter::createBlock; + + /// PatternRewriter hook for creating a new block with the given arguments. + Block *createBlock(Region *parent, Region::iterator insertPt = {}, + TypeRange argTypes = llvm::None) override; + /// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -339,24 +339,28 @@ return op; } -/// Add new block and set the insertion point to the end of it. The block is -/// inserted at the provided insertion point of 'parent'. -Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) { +/// Add new block with 'argTypes' arguments and set the insertion point to the +/// end of it. The block is inserted at the provided insertion point of +/// 'parent'. +Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt, + TypeRange argTypes) { assert(parent && "expected valid parent region"); if (insertPt == Region::iterator()) insertPt = parent->end(); Block *b = new Block(); + b->addArguments(argTypes); parent->getBlocks().insert(insertPt, b); setInsertionPointToEnd(b); return b; } -/// Add new block and set the insertion point to the end of it. The block is -/// placed before 'insertBefore'. -Block *OpBuilder::createBlock(Block *insertBefore) { +/// Add new block with 'argTypes' arguments and set the insertion point to the +/// end of it. The block is placed before 'insertBefore'. +Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) { assert(insertBefore && "expected valid insertion block"); - return createBlock(insertBefore->getParent(), Region::iterator(insertBefore)); + return createBlock(insertBefore->getParent(), Region::iterator(insertBefore), + argTypes); } /// Create an operation given the fields represented as an OperationState. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -585,6 +585,9 @@ /// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues); + /// Notifies that a block was created. + void notifyCreatedBlock(Block *block); + /// Notifies that a block was split. void notifySplitBlock(Block *block, Block *continuation); @@ -804,6 +807,10 @@ markNestedOpsIgnored(op); } +void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) { + blockActions.push_back(BlockAction::getCreate(block)); +} + void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, Block *continuation) { blockActions.push_back(BlockAction::getSplit(continuation, block)); @@ -910,6 +917,15 @@ return impl->mapping.lookupOrDefault(key); } +/// PatternRewriter hook for creating a new block with the given arguments. +Block *ConversionPatternRewriter::createBlock(Region *parent, + Region::iterator insertPt, + TypeRange argTypes) { + Block *block = PatternRewriter::createBlock(parent, insertPt, argTypes); + impl->notifyCreatedBlock(block); + return block; +} + /// PatternRewriter hook for splitting a block into two parts. Block *ConversionPatternRewriter::splitBlock(Block *block, Block::iterator before) { diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -130,6 +130,19 @@ return %0 : i32 } +// CHECK-LABEL: @create_block +func @create_block() { + "test.container"() ({ + // Check that we created a block with arguments. + // CHECK-NOT: test.create_block + // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + // CHECK: test.finish + "test.create_block"() : () -> () + "test.finish"() : () -> () + }) : () -> () + return +} + // ----- func @fail_to_convert_illegal_op() -> i32 { @@ -163,3 +176,17 @@ }) : () -> () return } + +// ----- + +// CHECK-LABEL: @create_illegal_block +func @create_illegal_block() { + "test.container"() ({ + // Check that we can undo block creation, i.e. that the block was removed.
+ // CHECK: test.create_illegal_block + // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + "test.create_illegal_block"() : () -> () + "test.finish"() : () -> () + }) : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -183,6 +183,41 @@ return success(); } }; +/// A simple pattern that creates a block at the end of the parent region of the +/// matched operation. +struct TestCreateBlock : public RewritePattern { + TestCreateBlock(MLIRContext *ctx) + : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + Region &region = *op->getParentRegion(); + Type i32Type = rewriter.getIntegerType(32); + rewriter.createBlock(&region, region.end(), {i32Type, i32Type}); + rewriter.create(op->getLoc()); + rewriter.replaceOp(op, {}); + return success(); + } +}; + +/// A simple pattern that creates a block containing an illegal operation in +/// order to trigger the block creation undo mechanism. +struct TestCreateIllegalBlock : public RewritePattern { + TestCreateIllegalBlock(MLIRContext *ctx) + : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + Region &region = *op->getParentRegion(); + Type i32Type = rewriter.getIntegerType(32); + rewriter.createBlock(&region, region.end(), {i32Type, i32Type}); + // Create an illegal op to ensure the conversion fails.
+ rewriter.create(op->getLoc(), i32Type); + rewriter.create(op->getLoc()); + rewriter.replaceOp(op, {}); + return success(); + } +}; //===----------------------------------------------------------------------===// // Type-Conversion Rewrite Testing @@ -373,12 +408,12 @@ TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); - patterns - .insert(&getContext()); + patterns.insert< + TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, + TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType, + TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, + TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, + TestNonRootReplacement>(&getContext()); patterns.insert(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); @@ -388,7 +423,8 @@ // Define the conversion target used for the test. ConversionTarget target(getContext()); target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); target .addIllegalOp(); target.addDynamicallyLegalOp([](TestReturnOp op) {