diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h @@ -16,6 +16,9 @@ #include "mlir/IR/Dialect.h" namespace mlir { + +struct TilingResult; + namespace tensor { class PadOp; @@ -39,10 +42,10 @@ /// to guard against the case that we might take a zero-sized slice from the /// original source. For such cases, we `tensor.generate` to generate the /// full tensor. -Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, - ArrayRef offsets, - ArrayRef sizes, - bool generateZeroSliceGuard = true); +FailureOr bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, + ArrayRef offsets, + ArrayRef sizes, + bool generateZeroSliceGuard = true); /// Registers external models for Tiling interface for tensor ops. /// Currently, it registers: diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -13,6 +13,9 @@ #include "mlir/IR/PatternMatch.h" namespace mlir { + +struct TilingResult; + namespace tensor { /// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op @@ -26,7 +29,7 @@ /// provide a mechanism to control where the application happens. With use of /// transform dialect that control is done within the transform dialect. Other /// use cases can inherit from this pattern and add necessary controls. -FailureOr replaceExtractSliceWithTiledProducer( +FailureOr replaceExtractSliceWithTiledProducer( OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp); /// Collects patterns to merge consecutive tensor.insert_slice/extract_slice diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h --- a/mlir/include/mlir/Interfaces/TilingInterface.h +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -21,6 +21,20 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" +namespace mlir { + +/// Container for result values of tiling. +/// - `tiledOps` contains operations created by the tiling implementation that +/// are returned to the caller for further transformations. +/// - `tiledValues` contains the tiled value corresponding to the result of the +/// untiled operation. +struct TilingResult { + SmallVector tiledOps; + SmallVector tiledValues; +}; + +} // namespace mlir + /// Include the ODS generated interface header files. #include "mlir/Interfaces/TilingInterface.h.inc" diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -63,7 +63,7 @@ The method returns the operation that is the tiled implementation. }], - /*retType=*/"SmallVector", + /*retType=*/"FailureOr", /*methodName=*/"getTiledImplementation", /*args=*/(ins "OpBuilder &":$b, @@ -119,7 +119,7 @@ iteration space). - `sizes` provides the size of the tile. }], - /*retType=*/"FailureOr", + /*retType=*/"FailureOr", /*methodName=*/"generateResultTileValue", /*args=*/(ins "OpBuilder &":$b, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -431,16 +431,15 @@ /// Find the first "extract" user of `producerOp` and tile it right before its /// use. The tiled op is fused under the `containingOp`. /// Return this fused op on success or nullptr if anything fails. -static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, - Diagnostic &diag, - Operation *producerOp, - Operation *containingOp) { +static SmallVector +tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, + Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; - return nullptr; + return {}; } // Search the producer slices accessed within the containing operation. @@ -455,7 +454,7 @@ if (it == tileableProducer->getUsers().end()) { diag.attachNote(tileableProducer->getLoc()) << "could not find fusion opportunity for: " << *tileableProducer; - return nullptr; + return {}; } auto sliceOpToTile = cast(*it); @@ -468,27 +467,29 @@ sliceOpToTile.getSource().cast().getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); - FailureOr tiledProducer = tileableProducer.generateResultTileValue( - rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), - sliceOpToTile.getMixedSizes()); - if (failed(tiledProducer)) { + FailureOr tileAndFuseResult = + tileableProducer.generateResultTileValue(rewriter, resultNumber, + sliceOpToTile.getMixedOffsets(), + sliceOpToTile.getMixedSizes()); + if (failed(tileAndFuseResult)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; - return nullptr; + return {}; + } + for (auto tiledOp : tileAndFuseResult->tiledOps) { + LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n"); } - LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. - Operation *fusedOp = tiledProducer->getDefiningOp(); auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( - rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), + rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], sliceOpToTile->getResult(0) .getType() .cast() .getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); - return fusedOp; + return tileAndFuseResult->tiledOps; } /// First, find the first "scf::ForallOp" user of `producerOp` and ensure @@ -497,7 +498,8 @@ /// right before its "extract" use. The tiled op is fused under the /// `containingOp`. /// Return this fused op on success or nullptr if anything fails. -static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( +static SmallVector +tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n"); @@ -506,7 +508,7 @@ if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; - return nullptr; + return {}; } // Search the first use by a "scf::ForallOp" user. @@ -520,7 +522,7 @@ if (!forallOp || forallOp != containingOp) { diag.attachNote(tileableProducer->getLoc()) << "could not find a use by the containing op: " << *tileableProducer; - return nullptr; + return {}; } // Search the producer slices accessed within the containing @@ -542,7 +544,7 @@ if (itBBArgUsers == bbArg.getUsers().end()) { diag.attachNote(containingOp->getLoc()) << "could not find fusion opportunity for bbArg: " << bbArg; - return nullptr; + return {}; } auto sliceOpToTile = cast(*itBBArgUsers); @@ -562,7 +564,7 @@ destinationTensors))) { diag.attachNote(tileableProducer->getLoc()) << "failed to get destination tensors for: " << *tileableProducer; - return nullptr; + return {}; } IRMapping bvm; @@ -573,21 +575,19 @@ llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); // Tile the producer. - FailureOr tiledProducer = + FailureOr tileAndFuseResult = tileableProducerClone.generateResultTileValue( rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), sliceOpToTile.getMixedSizes()); - if (failed(tiledProducer)) { + if (failed(tileAndFuseResult)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; - return nullptr; + return {}; } - LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. - Operation *fusedOp = tiledProducer->getDefiningOp(); auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( - rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), + rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], sliceOpToTile->getResult(0) .getType() .cast() @@ -601,7 +601,7 @@ destinationTensors.front()); }); - return fusedOp; + return tileAndFuseResult->tiledOps; } static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, @@ -714,21 +714,21 @@ // cases, we can tile/clone once and reuse the value for each use. // Futhermore, producers should then be traversed according to a // topological sorting. - Operation *tiled = + SmallVector tiledOps = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); - if (tiled) { + if (!tiledOps.empty()) { LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp); - fusedOps.push_back(tiled); + fusedOps.append(tiledOps); continue; } - Operation *tiledContainingOpOperand = + SmallVector tiledContainingOpOperand = tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); - if (tiledContainingOpOperand) { + if (!tiledContainingOpOperand.empty()) { LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n" << *containingOp); - fusedOps.push_back(tiledContainingOpOperand); + fusedOps.append(tiledContainingOpOperand); continue; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -41,26 +41,26 @@ offsetsCopy[dimension] = offset; // Create the part as it it were a single tile. - SmallVector tiled = + FailureOr tilingResult = op.getTiledImplementation(b, offsetsCopy, sizesCopy); - assert(tiled.size() == 1 && "expected a single result from tiling"); - auto part = cast(tiled.front()); // Insert the results back and populate the `results` list. - for (auto i : llvm::seq(0, part->getNumResults())) { + for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) { SmallVector resultOffsets, resultSizes; - if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy, + if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy, resultOffsets, resultSizes))) return nullptr; SmallVector resultStrides(resultOffsets.size(), b.getIndexAttr(1)); Value inserted = b.create( - loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes, + loc, result, resultOperands[index], resultOffsets, resultSizes, resultStrides); results.push_back(inserted); } - - return part; + // TODO: this part can be generalized maybe to not expect a single op. + assert(tilingResult->tiledOps.size() == 1 && + "expected split part to return a single tiled operation"); + return cast(tilingResult->tiledOps[0]); } std::pair diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -388,12 +388,13 @@ } // 4. Tile the cloned op and delete the clone. - SmallVector tiledOps = + FailureOr tilingResult = cast(clonedOp).getTiledImplementation(b, tiledOffsets, tiledSizes); b.eraseOp(clonedOp); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); + assert(tilingResult->tiledOps.size() == 1 && + "expected a single produced tiled op"); + tiledOp = tilingResult->tiledOps.front(); } // 5. Parallel insert back into the result tensor. @@ -729,12 +730,13 @@ // 5. Tile the cloned op and delete the clone. if (tileSizes.empty()) { - SmallVector tiledOps = + FailureOr tilingResult = cast(clonedOp).getTiledImplementation( b, tiledOffsets, tiledSizes); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); - tilingResults = tiledOp->getResults(); + assert(tilingResult->tiledOps.size() == 1 && + "expected a single produced tiled op"); + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; } else { LinalgTilingOptions options; FailureOr maybeTiled = tileLinalgOpImpl( diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -111,7 +111,7 @@ } // Instantiate the tiled implementation of the operation. - SmallVector + FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { @@ -129,7 +129,7 @@ Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); offsetIndices(b, cast(tiledOp), offsets); - return {tiledOp}; + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; } // Return the details of the output tile generated by the tiled @@ -160,10 +160,10 @@ return success(); } - FailureOr generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ArrayRef offsets, - ArrayRef sizes) const { + FailureOr + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes) const { auto linalgOp = cast(op); // Check that the indexing map used for the output is a projected @@ -197,12 +197,15 @@ iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; } - SmallVector tiledOp = tilingInterfaceOp.getTiledImplementation( - b, iterationTileOffsets, iterationTileSizes); - if (tiledOp.size() != 1) + FailureOr tilingResult = + tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets, + iterationTileSizes); + if (tilingResult->tiledOps.size() != 1) return op->emitOpError("failed to generate tiled implementation"); - return tiledOp[0]->getResult(resultNumber); + return TilingResult{ + tilingResult->tiledOps, + SmallVector{tilingResult->tiledValues[resultNumber]}}; } LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -952,12 +952,14 @@ return failure(); } - Operation *tiledPadOp = + FailureOr tilingResult = tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), zeroSliceGuard); + if (failed(tilingResult)) + return failure(); // All shapes are static and the data source is actually used. Rewrite into // pad(extract_slice(x)). - rewriter.replaceOp(sliceOp, tiledPadOp->getResults()); + rewriter.replaceOp(sliceOp, tilingResult->tiledValues); return success(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -251,18 +251,20 @@ /// a destination passing style op. static SmallVector yieldTiledValues(RewriterBase &rewriter, ArrayRef initValues, - Operation *tiledOp, + TilingResult tilingResult, ArrayRef> tileOffsetsList, ArrayRef> tileSizesList, MutableArrayRef loops) { SmallVector replacements = - yieldTiledValues(rewriter, initValues, tiledOp->getResults(), + yieldTiledValues(rewriter, initValues, tilingResult.tiledValues, tileOffsetsList, tileSizesList, loops); - if (auto dstOp = dyn_cast(tiledOp)) { - auto innerMostLoop = loops.back(); - SmallVector tiledOpDestinationTensors = dstOp.getDpsInitOperands(); - updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, - innerMostLoop.getRegionIterArgs()); + for (auto tiledOp : tilingResult.tiledOps) { + if (auto dstOp = dyn_cast(tiledOp)) { + auto innerMostLoop = loops.back(); + SmallVector tiledOpDestinationTensors = dstOp.getDpsInitOperands(); + updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, + innerMostLoop.getRegionIterArgs()); + } } return replacements; } @@ -345,9 +347,9 @@ if (!tilingResult.loops.empty()) rewriter.setInsertionPoint( tilingResult.loops.back().getBody()->getTerminator()); - SmallVector tiledImplementation = + FailureOr tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes); - tilingResult.tiledOps.append(tiledImplementation); + tilingResult.tiledOps.append(tiledImplementation->tiledOps); if (op->getNumResults() == 0) { // nothing more to do. return tilingResult; @@ -356,9 +358,7 @@ // If loops are empty, the tiled op is used as the replacement for the untiled // op. if (tilingResult.loops.empty()) { - tilingResult.replacements = llvm::to_vector( - llvm::map_range(tiledImplementation[0]->getResults(), - [](OpResult result) -> Value { return result; })); + tilingResult.replacements = tiledImplementation->tiledValues; return tilingResult; } @@ -384,7 +384,7 @@ return rewriter.notifyMatchFailure(op, "failed to get destinations"); tilingResult.replacements = yieldTiledValues( - rewriter, destinationTensors, tilingResult.tiledOps.back(), + rewriter, destinationTensors, tiledImplementation.value(), resultOffsetsList, resultSizesList, tilingResult.loops); LLVM_DEBUG({ @@ -523,12 +523,13 @@ // 2. Generate the tiled implementation of the producer of the source OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(candidateSliceOp); - FailureOr fusedProducerValue = + FailureOr tileAndFuseResult = tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, fusableProducer); - if (failed(fusedProducerValue)) + if (failed(tileAndFuseResult)) return std::nullopt; - rewriter.replaceAllUsesWith(candidateSliceOp, fusedProducerValue.value()); + rewriter.replaceAllUsesWith(candidateSliceOp, + tileAndFuseResult->tiledValues[0]); // 3. If the slice is for a destination operand, for example, // @@ -592,8 +593,10 @@ outerMostLoop.setIterArg(iterArgNumber.value(), dstOp.getTiedOpOperand(fusableProducer)->get()); } - if (auto dstOp = fusedProducerValue.value() - .getDefiningOp()) { + for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) { + auto dstOp = dyn_cast(tileAndFusedOp); + if (!dstOp) + continue; scf::ForOp innerMostLoop = loops.back(); updateDestinationOperandsForTiledOp( rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), @@ -601,7 +604,7 @@ } } return scf::SCFFuseProducerOfSliceResult{fusableProducer, - fusedProducerValue.value()}; + tileAndFuseResult->tiledValues[0]}; } /// Reconstruct the fused producer from within the tiled-and-fused code. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -46,15 +46,15 @@ return loopRanges; } - SmallVector + FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { - Operation *result = + FailureOr result = tensor::bubbleUpPadSlice(b, cast(op), offsets, sizes); - if (!result) - return {}; - return {result}; + if (failed(result)) + return failure(); + return result.value(); } LogicalResult @@ -117,7 +117,7 @@ return getPackUnPackIterationDomain(cast(op), b); } - SmallVector + FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { @@ -192,7 +192,8 @@ Operation *tiledPackOp = b.create( loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); - return {tiledPackOp}; + return TilingResult{{tiledPackOp}, + SmallVector(tiledPackOp->getResults())}; } LogicalResult @@ -353,7 +354,7 @@ /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we /// can get the actual result. - SmallVector + FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { @@ -412,12 +413,13 @@ loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); if (isPerfectTilingCase) - return {tiledUnpackOp}; + return TilingResult{{tiledUnpackOp}, + SmallVector(tiledUnpackOp->getResults())}; - Operation *extractSlice = + auto extractSlice = b.create(loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, destStrides); - return {tiledUnpackOp, extractSlice}; + return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}}; } LogicalResult @@ -431,26 +433,29 @@ return success(); } - FailureOr generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ArrayRef offsets, - ArrayRef sizes) const { - return getTiledImplementation(op, b, offsets, sizes) - .back() - ->getResult(resultNumber); + FailureOr + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes) const { + FailureOr tilingResult = + getTiledImplementation(op, b, offsets, sizes); + if (failed(tilingResult)) + return failure(); + return tilingResult.value(); } }; } // namespace -Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, - ArrayRef offsets, - ArrayRef sizes, - bool generateZeroSliceGuard) { +FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, + tensor::PadOp padOp, + ArrayRef offsets, + ArrayRef sizes, + bool generateZeroSliceGuard) { // Only constant padding value supported. Value padValue = padOp.getConstantPaddingValue(); if (!padValue) - return nullptr; + return failure(); // Helper variables and functions for various arithmetic operations. These // are used extensively for computing new offset/length and padding values. @@ -584,10 +589,9 @@ RankedTensorType::get(shape, padOp.getResultType().getElementType()); // Insert cast to ensure that types match. (May be folded away.) - auto castResult = [&](Operation *op) -> Operation * { - Value val = op->getResult(0); + auto castResult = [&](Value val) -> Value { if (resultType == val.getType()) - return op; + return val; return b.create(loc, resultType, val); }; @@ -601,7 +605,7 @@ [&](OpBuilder &builder, Location gLoc, ValueRange indices) { builder.create(gLoc, padValue); }); - return castResult(generateOp); + return generateOp; }; // Emit a SliceOp and a PadOp. Should not be used in cases where @@ -617,30 +621,38 @@ padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm); // Cast result and return. - return castResult(newPadOp); + return newPadOp; }; // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that // the original data source x is not used. - if (hasZeroLen) - return createGenerateOp(); + if (hasZeroLen) { + Operation *generateOp = createGenerateOp(); + return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}}; + } // If there are dynamic dimensions: Generate an scf.if check to avoid // creating SliceOps with result dimensions of size 0 at runtime. if (generateZeroSliceGuard && dynHasZeroLenCond) { + Operation *thenOp; + Operation *elseOp; auto result = b.create( loc, dynHasZeroLenCond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { - b.create(loc, createGenerateOp()->getResult(0)); + thenOp = createGenerateOp(); + b.create(loc, castResult(thenOp->getResult(0))); }, /*elseBuilder=*/ [&](OpBuilder &b, Location loc) { - b.create(loc, createPadOfExtractSlice()->getResult(0)); + elseOp = createPadOfExtractSlice(); + b.create(loc, castResult(elseOp->getResult(0))); }); - return result; + return TilingResult{{result}, SmallVector(result->getResults())}; } - return createPadOfExtractSlice(); + + Operation *newPadOp = createPadOfExtractSlice(); + return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}}; } void mlir::tensor::registerTilingInterfaceExternalModels( diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -20,7 +20,7 @@ using namespace mlir; -FailureOr tensor::replaceExtractSliceWithTiledProducer( +FailureOr tensor::replaceExtractSliceWithTiledProducer( OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { auto producerOp = dyn_cast(producer.getOwner()); if (!producerOp) @@ -32,7 +32,7 @@ })) return failure(); - FailureOr tiledResult = producerOp.generateResultTileValue( + FailureOr tiledResult = producerOp.generateResultTileValue( builder, producer.getResultNumber(), sliceOp.getMixedOffsets(), sliceOp.getMixedSizes()); if (failed(tiledResult))