VectorToSCF: pass OpBuilder to the Strategy hooks and emit the unpacking loop
explicitly (ProgressiveVectorToSCF.cpp).

What this patch changes:
 * dropFirstElem and the Strategy<TransferReadOp>/Strategy<TransferWriteOp>
   hooks rewriteOp and handleOutOfBoundsDim now take `OpBuilder &builder`
   instead of `PatternRewriter &rewriter`. Every use inside them
   (dropFirstElem, unpackedPermutationMap, getContext, getUnitAttr) is
   switched to `builder`.
 * Strategy::cleanup keeps its PatternRewriter parameter and is still called
   with `rewriter`.
 * In the conversion pattern, `affineLoopBuilder(lb, ub, 1, ...)` is replaced
   by `rewriter.create(...)`. The call receives lb, ub, an explicit `step`
   (std_constant_index(1)), an empty ValueRange of loop-carried values, and a
   body-builder lambda.
   NOTE(review): the created op type is not visible in this copy. The lambda
   signature (builder, loc, iv, loopState) and the "Generate for loop."
   comment suggest scf::ForOp; confirm.
 * The body lambda opens a ScopedContext on its builder `b` and passes `b`
   (not `rewriter`) to generateInBoundsCheck and to both Strategy hooks. It
   ends with `b.create(loc)`, presumably the loop terminator (e.g.
   scf::YieldOp); confirm.

Review notes:
 * The inBoundsCase/outOfBoundsCase lambdas name their parameter `b`. That
   parameter shadows the enclosing body-builder `b`. This is legal C++, but
   shadowing warnings (e.g. -Wshadow) may flag it. A distinct name would also
   make it obvious which builder the Strategy hooks receive.
 * The two handleOutOfBoundsDim overloads treat their builder parameter
   differently:
     - The TransferReadOp overload comments the name out
       (`OpBuilder & /*builder*/`).
     - The empty TransferWriteOp overload keeps `builder`, `xferOp`,
       `buffer` and `iv` named but unused.
 * NOTE(review): this copy of the patch looks damaged.
     - Newlines are collapsed.
     - Template argument lists appear stripped (`SmallVector storeIndices`,
       `dyn_cast()`, `rewriter.create(`, `b.create(loc)`).
   Restore it from the original commit before applying; it will not apply
   as-is.

diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp @@ -194,10 +194,10 @@ } /// Given an ArrayAttr, return a copy where the first element is dropped. -static ArrayAttr dropFirstElem(PatternRewriter &rewriter, ArrayAttr attr) { +static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) { if (!attr) return attr; - return ArrayAttr::get(rewriter.getContext(), attr.getValue().drop_front()); + return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front()); } /// Codegen strategy, depending on the operation. @@ -261,8 +261,8 @@ /// /// Note: The loop and type cast are generated in TransferOpConversion. /// The original TransferReadOp and store op are deleted in `cleanup`. - static void rewriteOp(PatternRewriter &rewriter, TransferReadOp xferOp, - Value buffer, Value iv) { + static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp, Value buffer, + Value iv) { SmallVector storeIndices; getStoreIndices(xferOp, storeIndices); storeIndices.push_back(iv); @@ -272,23 +272,23 @@ auto bufferType = buffer.getType().dyn_cast(); auto vecType = bufferType.getElementType().dyn_cast(); - auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr()); + auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr()); auto newXfer = vector_transfer_read( vecType, xferOp.source(), xferIndices, - AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)), + AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), xferOp.padding(), Value(), inBoundsAttr) .value; if (vecType.getRank() > kTargetRank) - newXfer.getDefiningOp()->setAttr(kPassLabel, rewriter.getUnitAttr()); + newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr()); memref_store(newXfer, buffer, storeIndices); } /// Handle out-of-bounds accesses on the to-be-unpacked 
dimension: Write /// padding value to the temporary buffer. - static void handleOutOfBoundsDim(PatternRewriter &rewriter, + static void handleOutOfBoundsDim(OpBuilder & /*builder*/, TransferReadOp xferOp, Value buffer, Value iv) { SmallVector storeIndices; @@ -341,7 +341,7 @@ /// to memory. /// /// Note: For more details, see comments on Strategy. - static void rewriteOp(PatternRewriter &rewriter, TransferWriteOp xferOp, + static void rewriteOp(OpBuilder &builder, TransferWriteOp xferOp, Value buffer, Value iv) { SmallVector loadIndices; getLoadIndices(xferOp, loadIndices); @@ -352,20 +352,19 @@ auto vec = memref_load(buffer, loadIndices); auto vecType = vec.value.getType().dyn_cast(); - auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr()); + auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr()); auto newXfer = vector_transfer_write( Type(), vec, xferOp.source(), xferIndices, - AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)), Value(), + AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(), inBoundsAttr); if (vecType.getRank() > kTargetRank) - newXfer.op->setAttr(kPassLabel, rewriter.getUnitAttr()); + newXfer.op->setAttr(kPassLabel, builder.getUnitAttr()); } /// Handle out-of-bounds accesses on the to-be-unpacked dimension. - static void handleOutOfBoundsDim(PatternRewriter &rewriter, - TransferWriteOp xferOp, Value buffer, - Value iv) {} + static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp, + Value buffer, Value iv) {} /// Cleanup after rewriting the op. static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) { @@ -503,22 +502,30 @@ auto castedType = unpackOneDim(bufferType); auto casted = vector_type_cast(castedType, buffer); + // Loop bounds and step. 
auto lb = std_constant_index(0).value; auto ub = std_constant_index(castedType.getDimSize(castedType.getRank() - 1)) .value; - affineLoopBuilder(lb, ub, 1, [&](Value iv) { - generateInBoundsCheck( - xferOp, iv, rewriter, unpackedDim(xferOp), - /*inBoundsCase=*/ - [&](OpBuilder & /*b*/, Location loc) { - Strategy::rewriteOp(rewriter, xferOp, casted, iv); - }, - /*outOfBoundsCase=*/ - [&](OpBuilder & /*b*/, Location loc) { - Strategy::handleOutOfBoundsDim(rewriter, xferOp, casted, iv); - }); - }); + auto step = std_constant_index(1).value; + + // Generate for loop. + rewriter.create( + xferOp.getLoc(), lb, ub, step, ValueRange(), + [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) { + ScopedContext scope(b, loc); + generateInBoundsCheck( + xferOp, iv, b, unpackedDim(xferOp), + /*inBoundsCase=*/ + [&](OpBuilder &b, Location /*loc*/) { + Strategy::rewriteOp(b, xferOp, casted, iv); + }, + /*outOfBoundsCase=*/ + [&](OpBuilder &b, Location /*loc*/) { + Strategy::handleOutOfBoundsDim(b, xferOp, casted, iv); + }); + b.create(loc); + }); Strategy::cleanup(rewriter, xferOp); return success();