diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -288,7 +288,7 @@ /// Generic vectorization function that rewrites the body of a `linalgOp` into /// vector form. Generic vectorization proceeds as follows: -/// 1. The region for the linalg op is created if necessary. +/// 1. Verify the `linalgOp` has one non-empty region. /// 2. Values defined above the region are mapped to themselves and will be /// broadcasted on a per-need basis by their consumers. /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d @@ -299,36 +299,21 @@ LogicalResult vectorizeAsLinalgGeneric( OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl &newResults, ArrayRef customVectorizationHooks = {}) { - // 1. Certain Linalg ops do not have a region but only a region builder. - // If so, build the region so we can vectorize. - std::unique_ptr owningRegion; - Region *region; - if (linalgOp->getNumRegions() > 0) { - region = &linalgOp->getRegion(0); - } else { - // RAII avoid remaining in block. - OpBuilder::InsertionGuard g(builder); - owningRegion = std::make_unique(); - region = owningRegion.get(); - Block *block = builder.createBlock(region); - auto elementTypes = llvm::to_vector<4>( - llvm::map_range(linalgOp.getShapedOperandTypes(), - [](ShapedType t) { return t.getElementType(); })); - block->addArguments(elementTypes); - linalgOp.getRegionBuilder()(*block, /*captures=*/{}); - } - Block *block = &region->front(); + // 1. Fail to vectorize if the operation does not have one non-empty region. + if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty()) + return failure(); + auto &block = linalgOp->getRegion(0).front(); BlockAndValueMapping bvm; // 2. Values defined above the region can only be broadcast for now. Make them // map to themselves. 
llvm::SetVector valuesSet; - mlir::getUsedValuesDefinedAbove(*region, valuesSet); + mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet); bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef()); // 3. Turn all BBArgs into vector.transfer_read / load. SmallVector indexings; - for (auto bbarg : block->getArguments()) { + for (auto bbarg : block.getArguments()) { Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber()); AffineMap map; VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg); @@ -360,7 +345,7 @@ hooks.push_back(vectorizeYield); // 5. Iteratively call `vectorizeOneOp` to each op in the slice. - for (Operation &op : block->getOperations()) { + for (Operation &op : block.getOperations()) { VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);