diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp @@ -107,8 +107,8 @@ indices[dim] = adaptor.indices()[dim] + iv; } -static void maybeYieldValue( - bool hasRetVal, OpBuilder builder, Location loc, Value value) { +static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc, + Value value) { if (hasRetVal) { builder.create(loc, value); } else { @@ -150,15 +150,19 @@ auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx); auto check = builder.create( xferOp.getLoc(), resultTypes, cond, - /*thenBuilder=*/[&](OpBuilder &builder, Location loc) { - maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc)); - }, /*elseBuilder=*/[&](OpBuilder &builder, Location loc) { - if (outOfBoundsCase) { - maybeYieldValue(hasRetVal, builder, loc, outOfBoundsCase(builder, loc)); - } else { - builder.create(loc); - } - }); + /*thenBuilder=*/ + [&](OpBuilder &builder, Location loc) { + maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc)); + }, + /*elseBuilder=*/ + [&](OpBuilder &builder, Location loc) { + if (outOfBoundsCase) { + maybeYieldValue(hasRetVal, builder, loc, + outOfBoundsCase(builder, loc)); + } else { + builder.create(loc); + } + }); return hasRetVal ? check.getResult(0) : Value(); } @@ -176,22 +180,24 @@ function_ref outOfBoundsCase = nullptr) { generateInBoundsCheck( xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(), - /*inBoundsCase=*/[&](OpBuilder &builder, Location loc) { + /*inBoundsCase=*/ + [&](OpBuilder &builder, Location loc) { inBoundsCase(builder, loc); return Value(); }, - /*outOfBoundsCase=*/[&](OpBuilder &builder, Location loc) { + /*outOfBoundsCase=*/ + [&](OpBuilder &builder, Location loc) { if (outOfBoundsCase) - outOfBoundsCase(builder, loc); + outOfBoundsCase(builder, loc); return Value(); }); } /// Given an ArrayAttr, return a copy where the first element is dropped. -static ArrayAttr dropFirstElem(PatternRewriter &rewriter, ArrayAttr attr) { +static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) { if (!attr) return attr; - return ArrayAttr::get(rewriter.getContext(), attr.getValue().drop_front()); + return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front()); } /// Codegen strategy, depending on the operation. @@ -256,8 +262,8 @@ /// /// Note: The loop and type cast are generated in TransferOpConversion. /// The original TransferReadOp and store op are deleted in `cleanup`. - static void rewriteOp(PatternRewriter &rewriter, TransferReadOp xferOp, - Value buffer, Value iv) { + static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp, Value buffer, + Value iv) { SmallVector storeIndices; getStoreIndices(xferOp, storeIndices); storeIndices.push_back(iv); @@ -267,25 +273,25 @@ auto bufferType = buffer.getType().dyn_cast(); auto vecType = bufferType.getElementType().dyn_cast(); - auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr()); + auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr()); auto newXfer = vector_transfer_read( vecType, xferOp.source(), xferIndices, - AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)), + AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), xferOp.padding(), Value(), inBoundsAttr) .value; if (vecType.getRank() > kTargetRank) - newXfer.getDefiningOp()->setAttr(kPassLabel, rewriter.getUnitAttr()); + newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr()); memref_store(newXfer, buffer, storeIndices); } /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write /// padding value to the temporary buffer. - static void handleOutOfBoundsDim( - PatternRewriter &rewriter, TransferReadOp xferOp, Value buffer, - Value iv) { + static void handleOutOfBoundsDim(OpBuilder & /*builder*/, + TransferReadOp xferOp, Value buffer, + Value iv) { SmallVector storeIndices; getStoreIndices(xferOp, storeIndices); storeIndices.push_back(iv); @@ -336,7 +342,7 @@ /// to memory. /// /// Note: For more details, see comments on Strategy. - static void rewriteOp(PatternRewriter &rewriter, TransferWriteOp xferOp, + static void rewriteOp(OpBuilder &builder, TransferWriteOp xferOp, Value buffer, Value iv) { SmallVector loadIndices; getLoadIndices(xferOp, loadIndices); @@ -347,20 +353,19 @@ auto vec = memref_load(buffer, loadIndices); auto vecType = vec.value.getType().dyn_cast(); - auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr()); + auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr()); auto newXfer = vector_transfer_write( Type(), vec, xferOp.source(), xferIndices, - AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)), Value(), + AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(), inBoundsAttr); if (vecType.getRank() > kTargetRank) - newXfer.op->setAttr(kPassLabel, rewriter.getUnitAttr()); + newXfer.op->setAttr(kPassLabel, builder.getUnitAttr()); } /// Handle out-of-bounds accesses on the to-be-unpacked dimension. - static void handleOutOfBoundsDim( - PatternRewriter &rewriter, TransferWriteOp xferOp, Value buffer, - Value iv) {} + static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp, + Value buffer, Value iv) {} /// Cleanup after rewriting the op. static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) { @@ -499,18 +504,29 @@ auto castedType = unpackOneDim(bufferType); auto casted = vector_type_cast(castedType, buffer); + // Loop bounds and step. auto lb = std_constant_index(0).value; auto ub = std_constant_index( castedType.getDimSize(castedType.getRank() - 1)).value; - affineLoopBuilder(lb, ub, 1, [&](Value iv) { - generateInBoundsCheck( - xferOp, iv, rewriter, unpackedDim(xferOp), - /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) { - Strategy::rewriteOp(rewriter, xferOp, casted, iv); - }, /*outOfBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) { - Strategy::handleOutOfBoundsDim(rewriter, xferOp, casted, iv); - }); - }); + auto step = std_constant_index(1).value; + + // Generate for loop. + rewriter.create( + xferOp.getLoc(), lb, ub, step, ValueRange(), + [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) { + ScopedContext scope(b, loc); + generateInBoundsCheck( + xferOp, iv, b, unpackedDim(xferOp), + /*inBoundsCase=*/ + [&](OpBuilder &b, Location /*loc*/) { + Strategy::rewriteOp(b, xferOp, casted, iv); + }, + /*outOfBoundsCase=*/ + [&](OpBuilder &b, Location /*loc*/) { + Strategy::handleOutOfBoundsDim(b, xferOp, casted, iv); + }); + b.create(loc); + }); Strategy::cleanup(rewriter, xferOp); return success(); @@ -546,25 +562,25 @@ /// Codegen strategy for TransferReadOp. template <> struct Strategy1d { - static void generateForLoopBody( - OpBuilder &builder, Location loc, TransferReadOp xferOp, Value iv, - ValueRange loopState) { + static void generateForLoopBody(OpBuilder &builder, Location loc, + TransferReadOp xferOp, Value iv, + ValueRange loopState) { SmallVector indices; auto dim = get1dMemrefIndices(xferOp, iv, indices); - auto ivI32 = std_index_cast( - IntegerType::get(builder.getContext(), 32), iv); + auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv); auto vec = loopState[0]; // In case of out-of-bounds access, leave `vec` as is (was initialized with // padding value). auto nextVec = generateInBoundsCheck( xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()), - /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) { - auto val = memref_load(xferOp.source(), indices); - return vector_insert_element(val, vec, ivI32.value).value; - }, /*outOfBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) { - return vec; - }); + /*inBoundsCase=*/ + [&](OpBuilder & /*b*/, Location loc) { + auto val = memref_load(xferOp.source(), indices); + return vector_insert_element(val, vec, ivI32.value).value; + }, + /*outOfBoundsCase=*/ + [&](OpBuilder & /*b*/, Location loc) { return vec; }); builder.create(loc, nextVec); } @@ -577,27 +593,24 @@ /// Codegen strategy for TransferWriteOp. template <> struct Strategy1d { - static void generateForLoopBody( - OpBuilder &builder, Location loc, TransferWriteOp xferOp, Value iv, - ValueRange /*loopState*/) { + static void generateForLoopBody(OpBuilder &builder, Location loc, + TransferWriteOp xferOp, Value iv, + ValueRange /*loopState*/) { SmallVector indices; auto dim = get1dMemrefIndices(xferOp, iv, indices); - auto ivI32 = std_index_cast( - IntegerType::get(builder.getContext(), 32), iv); + auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv); // Nothing to do in case of out-of-bounds access. generateInBoundsCheck( xferOp, iv, builder, dim, - /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) { - auto val = vector_extract_element(xferOp.vector(), ivI32.value); - memref_store(val, xferOp.source(), indices); - }); + /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) { + auto val = vector_extract_element(xferOp.vector(), ivI32.value); + memref_store(val, xferOp.source(), indices); + }); builder.create(loc); } - static Value initialLoopState(TransferWriteOp xferOp) { - return Value(); - } + static Value initialLoopState(TransferWriteOp xferOp) { return Value(); } }; /// Lower a 1D vector transfer op that operates on a dimension different from @@ -631,11 +644,11 @@ auto map = xferOp.permutation_map(); if (xferOp.getVectorType().getRank() != 1) - return failure(); - if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM - return failure(); + return failure(); + if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM + return failure(); if (xferOp.mask()) - return failure(); + return failure(); // Loop bounds, step, state... auto vecType = xferOp.getVectorType(); @@ -648,10 +661,10 @@ rewriter.replaceOpWithNewOp( xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) { - ScopedContext nestedScope(builder, loc); - Strategy1d::generateForLoopBody( - builder, loc, xferOp, iv, loopState); - }); + ScopedContext nestedScope(builder, loc); + Strategy1d::generateForLoopBody(builder, loc, xferOp, iv, + loopState); + }); return success(); } @@ -689,3 +702,4 @@ mlir::createProgressiveConvertVectorToSCFPass() { return std::make_unique(); } +