diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -520,16 +520,8 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ unsigned bbArgNumber = opOperand->getOperandNumber(); - // Safeguard against the named linalg ops that are manually defined and - // that only support buffer semantics: we should not be there. - // Such ops have an empty regionBuilder and are not constructed with a - // region for now. In the future they are slated to disappear. - assert(this->getOperation()->getNumRegions() == 1 && "unexpected " - "missing region (calling `payloadUsesValueFromOperand` on " - "manually defined named Linalg op?)"); - Block &block = this->getOperation()->getRegion(0).front(); // Init tensors have uses. - return !block.getArgument(bbArgNumber).use_empty(); + return !getBlock()->getArgument(bbArgNumber).use_empty(); }] >, InterfaceMethod< @@ -604,8 +596,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - Block &entryBlock = this->getOperation()->getRegion(0).front(); - return entryBlock.getArguments().take_back(this->getNumOutputs()); + return getBlock()->getArguments().take_back(this->getNumOutputs()); }] >, InterfaceMethod< @@ -671,6 +662,21 @@ //===------------------------------------------------------------------===// // Other interface methods. //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Return the single block constituting the body of the operation by + calling the getBody method on the concrete operation. + }], + /*retTy=*/"Block*", + /*methodName=*/"getBlock", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Assume the concrete operation implements the + // SingleBlockImplicitTerminator trait. + return $_op.getBody(); + }] + >, InterfaceMethod< /*desc=*/[{ Return the iterator types attribute within the current operation. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -26,14 +26,14 @@ // depending on the specific Linalg op. class LinalgStructuredBase_Op props> : Op { + SingleBlockImplicitTerminator<"YieldOp">, + DeclareOpInterfaceMethods, + LinalgStructuredInterface, + ReifyRankedShapedTypeOpInterface], props)> { code structuredOpsBaseDecls = [{ // Return whether the op accesses the iteration indices. bool hasIndexSemantics() { - Operation *op = this->getOperation(); - if(op->getNumRegions() == 0 || op->getRegion(0).empty()) - return false; - return !op->getRegion(0).front().getOps().empty(); + return !this->getBody()->getOps().empty(); } LogicalResult reifyResultShapes(OpBuilder &b, @@ -45,9 +45,7 @@ } class LinalgStructured_Op props> - : LinalgStructuredBase_Op])> { + : LinalgStructuredBase_Op { code structuredOpsDecls = structuredOpsBaseDecls # [{ std::string getLibraryCallName() { return generateLibraryCallName(getOperation()); @@ -226,10 +224,7 @@ // Generic Linalg ops. //===----------------------------------------------------------------------===// -def GenericOp : LinalgStructuredBase_Op<"generic", [ - AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"YieldOp">]> { +def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> { let description = [{ Generic Linalg op form where the key properties of the computation are specified as attributes. In pretty form, a `linalg.generic` op is written diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -683,17 +683,10 @@ } } - // Named ops that are defined manually have a region builder but no region at - // this time. Assume the region is well-formed by specification. - // TODO: use linalg-ods-gen for all ops when we have enough expressive power. - if (linalgOp->getNumRegions() == 0) { - assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region"); - return success(); - } - - auto ®ion = linalgOp->getRegion(0); - if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region)) - return op->emitOpError("expected 1 region with 1 block"); + // Check the region has exactly one block. + if (linalgOp->getNumRegions() != 1 || + !llvm::hasSingleElement(linalgOp->getRegion(0))) + return op->emitOpError("expects to have 1 region with 1 block"); if (!linalgOp.getShapesToLoopsMap()) return op->emitOpError("expected the shape-to-loops map to be non-null"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -149,12 +149,8 @@ static void replaceUnitDimIndexOps(GenericOp genericOp, const DenseSet &unitDims, PatternRewriter &rewriter) { - assert(genericOp->getNumRegions() == 1 && - genericOp->getRegion(0).getBlocks().size() == 1 && - "expected generic operation to have one block."); - Block &block = genericOp->getRegion(0).front(); - - for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps())) { + for (IndexOp indexOp : + llvm::make_early_inc_range(genericOp.getBody()->getOps())) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(indexOp); if (unitDims.count(indexOp.dim()) != 0) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -50,27 +50,15 @@ SmallVector resultTypes = namedOp.getOutputTensorTypes(); SmallVector types(resultTypes.begin(), resultTypes.end()); - // Inline the existing region if the named operation has a region attached. - if (namedOp->getNumRegions() == 1) { - GenericOp genericOp = - rewriter.create(namedOp.getLoc(), types, inputOperands, - outputOperands, indexingMaps, iterators); - rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), - genericOp.region().begin()); - return genericOp; - } - - // Otherwise use the region builder to generate a new region. - // TODO: Remove this path once all linag operations have a region attached. - auto regionBuilder = namedOp.getRegionBuilder(); - assert(regionBuilder && "expect the operation to have region builder"); - return rewriter.create( - namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, - iterators, - [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { - ImplicitLocOpBuilder b(loc, bodyBuilder); - regionBuilder(b, *bodyBuilder.getBlock()); - }); + // All named ops have a region attached that can be inlined. + assert(namedOp->getNumRegions() == 1 && + "expect named op to have one region attached"); + GenericOp genericOp = + rewriter.create(namedOp.getLoc(), types, inputOperands, + outputOperands, indexingMaps, iterators); + rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), + genericOp.region().begin()); + return genericOp; } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -77,15 +77,9 @@ // 4. Transform the index operations by applying the permutation map. if (genericOp.hasIndexSemantics()) { - // TODO: Remove the assertion and add a getBody() method to LinalgOp - // interface once every LinalgOp has a body. - assert(genericOp->getNumRegions() == 1 && - genericOp->getRegion(0).getBlocks().size() == 1 && - "expected generic operation to have one block."); - Block &block = genericOp->getRegion(0).front(); OpBuilder::InsertionGuard guard(rewriter); for (IndexOp indexOp : - llvm::make_early_inc_range(block.getOps())) { + llvm::make_early_inc_range(genericOp.getBody()->getOps())) { rewriter.setInsertionPoint(indexOp); SmallVector allIndices; allIndices.reserve(genericOp.getNumLoops()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -479,8 +479,6 @@ if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity()) return false; } - if (linalgOp->getNumRegions() != 1) - return false; return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0)); } @@ -510,10 +508,7 @@ OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl &newResults, bool broadcastToMaximalCommonShape = false, ArrayRef customVectorizationHooks = {}) { - // 1. Fail to vectorize if the operation does not have one non-empty region. - if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty()) - return failure(); - auto &block = linalgOp->getRegion(0).front(); + Block *block = linalgOp.getBlock(); // 2. Values defined above the region can only be broadcast for now. Make them // map to themselves. @@ -533,7 +528,7 @@ // 3. Turn all BBArgs into vector.transfer_read / load. SmallVector indexings; for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber()); + BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber()); if (linalgOp.isScalar(opOperand)) { bvm.map(bbarg, opOperand->get()); continue; @@ -580,7 +575,7 @@ hooks.push_back(vectorizeIndex); // 5. Iteratively call `vectorizeOneOp` to each op in the slice. - for (Operation &op : block.getOperations()) { + for (Operation &op : block->getOperations()) { VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { LDBG("failed to vectorize: " << op); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -768,12 +768,7 @@ void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp, ArrayRef ivs) { if (tiledOp.hasIndexSemantics()) { - assert(tiledOp->getNumRegions() == 1 && - tiledOp->getRegion(0).getBlocks().size() == 1 && - "expect producer to have one block."); - // Shift all IndexOp results by the tile offset. - Block &block = tiledOp->getRegion(0).front(); - for (IndexOp indexOp : block.getOps()) { + for (IndexOp indexOp : tiledOp.getBlock()->getOps()) { if (ivs[indexOp.dim()] == nullptr) continue; OpBuilder::InsertionGuard guard(b); diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -160,7 +160,7 @@ func @generic_empty_region(%arg0: memref) { %f0 = arith.constant 0.0: f32 - // expected-error @+1 {{op expected 1 region with 1 block}} + // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}} linalg.generic { indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ], iterator_types = []} @@ -177,7 +177,7 @@ func @generic_empty_region(%arg0: memref) { %f0 = arith.constant 0.0: f32 - // expected-error @+1 {{linalg.generic' op expected 1 region with 1 block}} + // expected-error @+1 {{op expects to have 1 region with 1 block}} linalg.generic { indexing_maps = [ affine_map<() -> ()> , affine_map<() -> ()> ], iterator_types = []} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2303,7 +2303,7 @@ } def TestLinalgConvOp : - TEST_Op<"linalg_conv_op", [AttrSizedOperandSegments, + TEST_Op<"linalg_conv_op", [AttrSizedOperandSegments, SingleBlock, LinalgStructuredInterface, LinalgConvolutionOpInterface]> { let arguments = (ins Variadic:$inputs, diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -443,10 +443,7 @@ // Op definition for {0} //===----------------------------------------------------------------------===// -def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([ - AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"YieldOp">], +def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], /*extraInterfaces=*/[{2}])> { {3} let arguments = (ins