diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -158,8 +158,8 @@
   SmallVector<Operation *, 8> loops;
   SmallVector<Value, 4> tensorResults;
 };
-Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
-                                     const LinalgTilingOptions &options);
+FailureOr<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
+                                      const LinalgTilingOptions &options);
 
 /// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
 /// proceeds as follows:
@@ -221,7 +221,7 @@
   /// The fused loop generated.
   SmallVector<Operation *, 4> fusedLoops;
 };
-Optional<TiledAndFusedLinalgOps>
+FailureOr<TiledAndFusedLinalgOps>
 tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                      const LinalgDependenceGraph &dependenceGraph,
                      const LinalgTilingOptions &tilingOptions);
@@ -344,7 +344,7 @@
   Value fullLocalView;
   Value partialLocalView;
 };
-Optional<PromotionInfo>
+FailureOr<PromotionInfo>
 promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
                           AllocBufferCallbackFn allocationFn,
                           DataLayout &layout);
@@ -359,24 +359,24 @@
 ///
 /// Returns the modified linalg op (the modification happens in place) as well
 /// as all the copy ops created.
-Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
-                                   LinalgPromotionOptions options);
+FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
+                                    LinalgPromotionOptions options);
 
 /// Emit a suitable vector form for a Linalg op with fully static shape.
 LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
                                 SmallVectorImpl<Value> &newResults);
 
 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
-Optional<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
-                                      LinalgOp linalgOp);
+FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
+                                       LinalgOp linalgOp);
 
 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
-Optional<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
-                                              LinalgOp linalgOp);
+FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
+                                               LinalgOp linalgOp);
 
 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
-Optional<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
-                                            LinalgOp linalgOp);
+FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
+                                             LinalgOp linalgOp);
 
 //===----------------------------------------------------------------------===//
 // Preconditions that ensure the corresponding transformation succeeds and can
@@ -961,15 +961,15 @@
     // TODO: Move lowering to library calls here.
     return failure();
   case LinalgLoweringType::Loops:
-    if (!linalgOpToLoops(rewriter, op))
+    if (failed(linalgOpToLoops(rewriter, op)))
      return failure();
    break;
  case LinalgLoweringType::AffineLoops:
-    if (!linalgOpToAffineLoops(rewriter, op))
+    if (failed(linalgOpToAffineLoops(rewriter, op)))
      return failure();
    break;
  case LinalgLoweringType::ParallelLoops:
-    if (!linalgOpToParallelLoops(rewriter, op))
+    if (failed(linalgOpToParallelLoops(rewriter, op)))
      return failure();
    break;
  }
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -164,25 +164,25 @@
 /// Implements the fusion part of the "tileAndFuse on buffers" transformation
 /// and thus requires the `consumerOpOperand` to be a `subview` op (generally
 /// obtained by applying the tiling transformation).
-Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
-                                          OpOperand &consumerOpOperand,
-                                          const LinalgDependenceGraph &graph);
+FailureOr<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
+                                           OpOperand &consumerOpOperand,
+                                           const LinalgDependenceGraph &graph);
 
 /// Tensor counterpart of `fuseProducerOfBuffer`.
 /// This implements the fusion part of the "tileAndFuse on tensors"
 /// transformation and thus requires the `consumerOpOperand` to be a
 /// `extract_slice` op (generally obtained by applying the tiling
 /// transformation).
-Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
-                                          OpOperand &consumerOpOperand);
+FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
+                                           OpOperand &consumerOpOperand);
 
 /// Tensor counterpart of `fuseProducerOfBuffer`.
 /// This implements the fusion part of the "tileAndFuse on tensors"
 /// transformation and thus requires the `consumerOpOperand` to be a
 /// `extract_slice` op (generally obtained by applying the tiling
 /// transformation). Assumes `producerOfTensor` is a Linalg op that produces
 /// `consumerOpOperand`.
-Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
-                                          OpResult producerOpResult,
-                                          OpOperand &consumerOpOperand);
+FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
+                                           OpResult producerOpResult,
+                                           OpOperand &consumerOpOperand);
 
 //===----------------------------------------------------------------------===//
 // Fusion on tensor utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -331,7 +331,7 @@
 /// For `consumer` with buffer semantics, find the Linalg operation on buffers
 /// that is the last writer of `consumerOpOperand`. For now the fusable
 /// dependence is returned as an instance of the `dependenceGraph`.
-static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+static FailureOr<LinalgDependenceGraph::LinalgDependenceGraphElem>
 findFusableProducer(OpOperand &consumerOpOperand,
                     const LinalgDependenceGraph &dependenceGraph) {
   LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
@@ -340,7 +340,7 @@
                           << *consumerOpOperand.getOwner() << "\n");
   LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
   if (!consumerOp)
-    return {};
+    return failure();
 
   // Only consider RAW and WAW atm.
   for (auto depType : {
@@ -386,37 +386,37 @@
       }
     }
   }
-  return {};
+  return failure();
 }
 
-Optional<FusionInfo>
+FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
                                    const LinalgDependenceGraph &graph) {
   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
       findFusableProducer(consumerOpOperand, graph);
   if (!fusableDependence)
-    return llvm::None;
+    return failure();
 
   LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
   if (!producerOp)
-    return llvm::None;
+    return failure();
 
   // If producer is already in the same block as consumer, we are done.
   if (consumerOpOperand.get().getParentBlock() ==
       fusableDependence->getDependentValue().getParentBlock())
-    return llvm::None;
+    return failure();
 
   Optional<AffineMap> producerMap =
       fusableDependence->getDependentOpViewIndexingMap();
   if (!producerMap)
-    return llvm::None;
+    return failure();
 
   // Must be a subview or an extract_slice to guarantee there are loops we can
   // fuse into.
   auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
   if (!subView) {
     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
-    return llvm::None;
+    return failure();
   }
 
   // Fuse `producer` just before `consumer`.
@@ -459,28 +459,28 @@
   }
 }
 
-Optional<FusionInfo>
+FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   Value inputTensor = consumerOpOperand.get();
   OpResult producerOpResult;
   getProducerOfTensor(inputTensor, producerOpResult);
   if (!producerOpResult) {
     LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
-    return {};
+    return failure();
   }
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
-Optional<FusionInfo>
+FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
   auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
   if (!producerOp)
-    return llvm::None;
+    return failure();
 
   LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
   if (!consumerOp)
-    return llvm::None;
+    return failure();
 
   Value inputTensor = consumerOpOperand.get();
 
@@ -489,13 +489,13 @@
   if (!sliceOp) {
     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable, not an extract_slice op: "
                             << inputTensor);
-    return {};
+    return failure();
   }
 
   // If producer is already in the same block as consumer, we are done.
   if (consumerOpOperand.get().getParentBlock() ==
       producerOpResult.getParentBlock())
-    return {};
+    return failure();
 
   // Insert fused `producer` just before `consumer`.
   OpBuilder::InsertionGuard g(b);
@@ -537,27 +537,27 @@
 /// - indexing map of the fused view in the producer : producerIndexMap
 ///     consumerLoopToProducerLoop =
 ///       inverse(producerIndexMap).compose(consumerIndexMap)
-static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
+static FailureOr<AffineMap> getConsumerLoopToProducerLoopMap(
     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
   auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
   if (!producer)
-    return None;
+    return failure();
 
   Optional<AffineMap> producerIndexingMap =
       dependence.getDependentOpViewIndexingMap();
   Optional<AffineMap> consumerIndexingMap =
       dependence.getIndexingOpViewIndexingMap();
   if (!producerIndexingMap || !consumerIndexingMap)
-    return None;
+    return failure();
 
   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
       producer.iterator_types().getValue(), *producerIndexingMap);
   if (!prunedProducerIndexingMap.isPermutation())
-    return None;
+    return failure();
 
   if (consumerIndexingMap->getNumResults() !=
       prunedProducerIndexingMap.getNumResults())
-    return None;
+    return failure();
 
   LLVM_DEBUG({
     llvm::dbgs() << "\t producerMap : ";
@@ -572,7 +572,7 @@
   AffineMap invProducerIndexMap =
       inversePermutation(prunedProducerIndexingMap);
   if (!invProducerIndexMap)
-    return None;
+    return failure();
 
   return invProducerIndexMap.compose(*consumerIndexingMap);
 }
@@ -776,7 +776,7 @@
 
 /// Tile the fused loops in the root operation, by setting the tile sizes for
 /// all other loops to zero (those will be tiled later).
-static Optional<TiledLinalgOp>
+static FailureOr<TiledLinalgOp>
 tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
                   const LinalgTilingOptions &options,
                   const std::set<unsigned> &fusedLoops) {
@@ -871,12 +871,12 @@
   return fusedOps;
 }
 
-static Optional<TiledAndFusedLinalgOps>
+static FailureOr<TiledAndFusedLinalgOps>
 tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
                          const LinalgDependenceGraph &dependenceGraph,
                          const LinalgTilingOptions &tilingOptions) {
   if (ops.size() < 2)
-    return llvm::None;
+    return failure();
   LinalgOp rootOp = ops.back();
   if (!llvm::all_of(
           ops,
@@ -887,13 +887,13 @@
     rootOp.emitError(
         "unable to fuse operations that have tensor semantics with operations "
         "that have buffer semantics and viceversa.");
-    return llvm::None;
+    return failure();
   }
   // TODO: Support interchange with tile + fuse. This might actually help do
   // better fusion.
   if (!tilingOptions.interchangeVector.empty()) {
     rootOp.emitRemark("unable to handle tile and fuse with interchange");
-    return llvm::None;
+    return failure();
   }
 
   OpBuilder::InsertionGuard guard(b);
@@ -905,7 +905,7 @@
       findAllFusableDependences(ops, dependenceGraph);
   if (fusableDependences.empty()) {
     LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
-    return llvm::None;
+    return failure();
   }
 
   TiledAndFusedLinalgOps ret;
@@ -917,17 +917,17 @@
   // just return.
   if (ret.fusedLoopDims.empty()) {
     LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
-    return llvm::None;
+    return failure();
   }
 
   // Tile the fused loops in the last operation in the list.
   SmallVector<Value, 4> tileSizeVector =
       tilingOptions.tileSizeComputationFunction(b, rootOp);
-  Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
+  FailureOr<TiledLinalgOp> tiledRootOp = tileRootOperation(
       b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
-  if (!tiledRootOp) {
+  if (failed(tiledRootOp)) {
     rootOp.emitRemark("failed to tile the fused loops");
-    return llvm::None;
+    return failure();
   }
   ret.op = tiledRootOp->op;
   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
@@ -939,7 +939,7 @@
   return ret;
 }
 
-Optional<TiledAndFusedLinalgOps>
+FailureOr<TiledAndFusedLinalgOps>
 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
                                    const LinalgDependenceGraph &dependenceGraph,
                                    const LinalgTilingOptions &tilingOptions) {
@@ -950,5 +950,5 @@
     return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
   default:;
   }
-  return llvm::None;
+  return failure();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -194,8 +194,8 @@
 }
 
 template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
-                                                 LinalgOp linalgOp) {
+static FailureOr<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
+                                                  LinalgOp linalgOp) {
   using LoadOpTy =
       typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                 AffineLoadOp, memref::LoadOp>::type;
@@ -227,12 +227,12 @@
   SetVector<Operation *> loopSet;
   for (Value iv : allIvs) {
     if (!iv)
-      return {};
+      return failure();
     // The induction variable is a block argument of the entry block of the
     // loop operation.
     BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
     if (!ivVal)
-      return {};
+      return failure();
     loopSet.insert(ivVal.getOwner()->getParentOp());
   }
   LinalgLoops loops(loopSet.begin(), loopSet.end());
@@ -253,7 +253,7 @@
     auto linalgOp = dyn_cast<LinalgOp>(op);
     if (!isa<LinalgOp>(op))
       return failure();
-    if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
+    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
       return failure();
     rewriter.eraseOp(op);
     return success();
@@ -547,20 +547,20 @@
 }
 
 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
-Optional<LinalgLoops>
+FailureOr<LinalgLoops>
 mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                     LinalgOp linalgOp) {
   return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
 }
 
 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
-Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
-                                                    LinalgOp linalgOp) {
+FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+                                                     LinalgOp linalgOp) {
   return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
 }
 
 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
-Optional<LinalgLoops>
+FailureOr<LinalgLoops>
 mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                       LinalgOp linalgOp) {
   return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -210,7 +210,7 @@
 // To account for general boundary effects, padding must be performed on the
 // boundary tiles. For now this is done with an unconditional `fill` op followed
 // by a partial `copy` op.
-Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
+FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     OpBuilder &b, Location loc, memref::SubViewOp subView,
     AllocBufferCallbackFn allocationFn, DataLayout &layout) {
   auto viewType = subView.getType();
@@ -236,7 +236,7 @@
   // allocating the promoted buffer.
   Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, layout);
   if (!fullLocalView)
-    return {};
+    return failure();
   SmallVector<OpFoldResult> zeros(fullSizes.size(), b.getIndexAttr(0));
   SmallVector<OpFoldResult> ones(fullSizes.size(), b.getIndexAttr(1));
   auto partialLocalView = b.createOrFold<memref::SubViewOp>(
@@ -244,21 +244,21 @@
   return PromotionInfo{*fullLocalView, partialLocalView};
 }
 
-static Optional<MapVector<int64_t, PromotionInfo>>
+static FailureOr<MapVector<int64_t, PromotionInfo>>
 promoteSubViews(ImplicitLocOpBuilder &b,
                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
   if (options.subViews.empty())
-    return {};
+    return failure();
 
   MapVector<int64_t, PromotionInfo> promotionInfoMap;
 
   for (auto v : options.subViews) {
     memref::SubViewOp subView =
         cast<memref::SubViewOp>(v.second.getDefiningOp());
-    Optional<PromotionInfo> promotionInfo = promoteSubviewAsNewBuffer(
+    auto promotionInfo = promoteSubviewAsNewBuffer(
         b, b.getLoc(), subView, options.allocationFn, layout);
-    if (!promotionInfo)
-      return {};
+    if (failed(promotionInfo))
+      return failure();
     promotionInfoMap[v.first] = *promotionInfo;
 
     // Only fill the buffer if the full local view is used
@@ -283,7 +283,7 @@
             })
             .Default([](auto) { return Value(); });
     if (!fillVal)
-      return {};
+      return failure();
     b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView);
   }
 
@@ -295,21 +295,21 @@
     if (failed(options.copyInFn(
             b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
            info->second.partialLocalView)))
-      return {};
+      return failure();
  }
  return promotionInfoMap;
 }
 
-static Optional<LinalgOp>
+static FailureOr<LinalgOp>
 promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
 
   // 1. Promote the specified views and use them in the new op.
   auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
-  if (!promotedBuffersAndViews ||
+  if (failed(promotedBuffersAndViews) ||
       promotedBuffersAndViews->size() != options.subViews.size())
-    return {};
+    return failure();
 
   // 2. Append all other operands as they appear, this enforces that such
   // operands are not views. This is to support cases such as FillOp taking
@@ -343,7 +343,7 @@
   for (auto viewAndPartialLocalView : writebackViews) {
     if (failed(options.copyOutFn(b, viewAndPartialLocalView.second,
                                  viewAndPartialLocalView.first)))
-      return {};
+      return failure();
   }
 
   // 4. Dealloc all local buffers.
@@ -374,13 +374,16 @@
   return failure();
 }
 
-Optional<LinalgOp>
+FailureOr<LinalgOp>
 mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp,
                               LinalgPromotionOptions options) {
   LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
   auto layout = DataLayout::closest(linalgOp);
   ImplicitLocOpBuilder b(linalgOp.getLoc(), builder);
-  return ::promoteSubViews(b, linalgOp, linalgOptions, layout);
+  auto res = ::promoteSubViews(b, linalgOp, linalgOptions, layout);
+  if (failed(res))
+    return failure();
+  return res;
 }
 
 namespace {
@@ -400,7 +403,8 @@
         return;
       LLVM_DEBUG(llvm::dbgs() << "Promote: " << *(op.getOperation()) << "\n");
       ImplicitLocOpBuilder b(op.getLoc(), op);
-      promoteSubViews(b, op, options);
+      // TODO: signalPassFailure() ?
+      (void)promoteSubViews(b, op, options);
     });
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -152,7 +152,7 @@
 }
 
 template <typename LoopTy>
-static Optional<TiledLinalgOp>
+static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                  const LinalgTilingOptions &options) {
   auto nLoops = op.getNumLoops();
@@ -160,13 +160,13 @@
   tileSizes = tileSizes.take_front(nLoops);
 
   if (llvm::all_of(tileSizes, isZero))
-    return llvm::None;
+    return failure();
 
   // 1. Build the tiled loop ranges.
   auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
   AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
   if (!shapeSizesToLoopsMap)
-    return llvm::None;
+    return failure();
 
   SmallVector<Range, 4> loopRanges;
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
@@ -291,13 +291,13 @@
 }
 
 template <typename LoopTy>
-Optional<TiledLinalgOp> static tileLinalgOpImpl(
+FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
     OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
 
   if (!options.tileSizeComputationFunction)
-    return llvm::None;
+    return failure();
 
   // Enforce the convention that "tiling by zero" skips tiling a particular
   // dimension. This convention is significantly simpler to handle instead of
@@ -313,7 +313,7 @@
   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
 }
 
-Optional<TiledLinalgOp>
+FailureOr<TiledLinalgOp>
 mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                            const LinalgTilingOptions &options) {
   switch (options.loopType) {
@@ -325,7 +325,7 @@
     return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
   default:;
   }
-  return llvm::None;
+  return failure();
 }
 
 /// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -187,26 +187,28 @@
       // removed.
       linalg::Aliases aliases;
       linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-      if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) {
-        auto *originalOp = info->originalProducer.getOperation();
-        eraseSet.insert(originalOp);
-        auto *originalOpInLinalgOpsVector =
-            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
-        changed = true;
-      }
+      auto info = fuseProducerOfBuffer(b, *opOperand, graph);
+      if (failed(info))
+        continue;
+      auto *originalOp = info->originalProducer.getOperation();
+      eraseSet.insert(originalOp);
+      auto *originalOpInLinalgOpsVector =
+          std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+      *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+      changed = true;
     } else if (opOperand->get().getType().isa<RankedTensorType>()) {
      // Tile and Fuse tensor input.
      if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
        continue;
-      if (auto info = fuseProducerOfTensor(b, *opOperand)) {
-        auto *originalOp = info->originalProducer.getOperation();
-        auto *originalOpInLinalgOpsVector =
-            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
-        // Don't mark for erasure in the tensor case, let DCE handle this.
-        changed = true;
-      }
+      auto info = fuseProducerOfTensor(b, *opOperand);
+      if (failed(info))
+        continue;
+      auto *originalOp = info->originalProducer.getOperation();
+      auto *originalOpInLinalgOpsVector =
+          std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+      *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+      // Don't mark for erasure in the tensor case, let DCE handle this.
+      changed = true;
     }
   }
 }
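---
Caller-side migration pattern: FailureOr<T> keeps Optional-style value access
(*res, res->field) while also converting to LogicalResult, so the mechanical
changes in this patch are Optional<T> -> FailureOr<T> in signatures,
`return llvm::None;` / `return {};` -> `return failure();`, and `if (!res)` ->
`if (failed(res))`. The standalone sketch below models that contract; it is
illustrative only, not MLIR's actual headers, and the names `TiledOp` and
`tile` are hypothetical stand-ins for the transforms touched above.

#include <cassert>
#include <optional>
#include <utility>

// Minimal stand-ins for mlir::LogicalResult / mlir::failure().
struct LogicalResult {
  bool isFailureState;
  static LogicalResult success() { return {false}; }
};
inline LogicalResult failure() { return {true}; }
inline bool failed(LogicalResult r) { return r.isFailureState; }

// Sketch of the FailureOr design: deriving from an optional keeps the old
// Optional accessors working, while the conversion to LogicalResult enables
// failed(...) checks at call sites.
template <typename T>
class FailureOr : public std::optional<T> {
public:
  FailureOr(LogicalResult r) : std::optional<T>() { assert(failed(r)); }
  FailureOr(T value) : std::optional<T>(std::move(value)) {}
  operator LogicalResult() const {
    return this->has_value() ? LogicalResult::success() : failure();
  }
};

// Hypothetical transform written in the new style: failure is an explicit
// result state rather than an empty Optional.
struct TiledOp {
  int numLoops;
};
FailureOr<TiledOp> tile(int tileSize) {
  if (tileSize == 0)
    return failure(); // was: return llvm::None;
  return TiledOp{tileSize};
}

int main() {
  FailureOr<TiledOp> res = tile(4);
  if (failed(res)) // was: if (!res)
    return 1;
  assert(res->numLoops == 4); // value access is unchanged from Optional
  assert(failed(tile(0)));
  return 0;
}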