diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -521,6 +521,8 @@
            "Perform full unrolling when converting vector transfers to SCF">,
     Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
            "Target vector rank to which transfer ops should be lowered">,
+    Option<"lowerTensors", "lower-tensors", "bool", /*default=*/"false",
+           "Lower transfer ops that operate on tensors">
   ];
 }
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -50,6 +50,7 @@ struct VectorTransferToSCFOptions {
   bool unroll = false;
   unsigned targetRank = 1;
+  bool lowerTensors = false;
 
   VectorTransferToSCFOptions &setUnroll(bool u) {
     unroll = u;
@@ -60,6 +61,11 @@
     targetRank = r;
     return *this;
   }
+
+  VectorTransferToSCFOptions &setLowerTensors(bool l) {
+    lowerTensors = l;
+    return *this;
+  }
 };
 
 /// Collect a set of patterns to convert from the Vector dialect to SCF + std.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -98,6 +98,7 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
                             Value value) {
   if (hasRetVal) {
+    assert(value && "Expected non-empty value");
     builder.create<scf::YieldOp>(loc, value);
   } else {
     builder.create<scf::YieldOp>(loc);
   }
@@ -238,6 +239,19 @@
   newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
 }
 
+/// Return true if this transfer op operates on a source tensor.
+template <typename OpTy>
+static bool isTensorOp(OpTy xferOp) {
+  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
+    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
+      // TransferWriteOps on tensors have a result.
+      assert(xferOp->getNumResults() > 0);
+    }
+    return true;
+  }
+  return false;
+}
+
 namespace lowering_n_d {
 
 /// Helper data structure for data and mask buffers.
@@ -360,8 +374,8 @@
   /// Note: The `mask` operand is set in TransferOpConversion.
   static TransferReadOp rewriteOp(OpBuilder &builder,
                                   VectorTransferToSCFOptions options,
-                                  TransferReadOp xferOp, Value buffer,
-                                  Value iv) {
+                                  TransferReadOp xferOp, Value buffer, Value iv,
+                                  ValueRange /*loopState*/) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -389,9 +403,9 @@
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
   /// padding value to the temporary buffer.
-  static void handleOutOfBoundsDim(OpBuilder & /*builder*/,
-                                   TransferReadOp xferOp, Value buffer,
-                                   Value iv) {
+  static Value handleOutOfBoundsDim(OpBuilder & /*builder*/,
+                                    TransferReadOp xferOp, Value buffer,
+                                    Value iv, ValueRange /*loopState*/) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -400,13 +414,19 @@
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
     auto vec = std_splat(vecType, xferOp.padding());
     memref_store(vec, buffer, storeIndices);
+
+    return Value();
   }
 
   /// Cleanup after rewriting the op.
-  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
+  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
+                      scf::ForOp /*forOp*/) {
     rewriter.eraseOp(getStoreOp(xferOp));
     rewriter.eraseOp(xferOp);
   }
+
+  /// Return the initial loop state for the generated scf.for loop.
+  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
 };
 
 /// Codegen strategy for vector TransferWriteOp.
@@ -445,7 +465,7 @@
   static TransferWriteOp rewriteOp(OpBuilder &builder,
                                    VectorTransferToSCFOptions options,
                                    TransferWriteOp xferOp, Value buffer,
-                                   Value iv) {
+                                   Value iv, ValueRange loopState) {
     SmallVector<Value, 8> loadIndices;
     getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
@@ -455,8 +475,10 @@
     auto vec = memref_load(buffer, loadIndices);
     auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
 
+    auto source = loopState.empty() ? xferOp.source() : loopState[0];
+    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
     auto newXfer = vector_transfer_write(
-        Type(), vec, xferOp.source(), xferIndices,
+        type, vec, source, xferIndices,
         AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
         inBoundsAttr);
@@ -466,20 +488,37 @@
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
-  static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
-                                   Value buffer, Value iv) {}
+  static Value handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
+                                    Value buffer, Value iv,
+                                    ValueRange loopState) {
+    return isTensorOp(xferOp) ? loopState[0] : Value();
+  }
 
   /// Cleanup after rewriting the op.
-  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
-    rewriter.eraseOp(xferOp);
+  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
+                      scf::ForOp forOp) {
+    if (isTensorOp(xferOp)) {
+      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
+      rewriter.replaceOp(xferOp, forOp->getResult(0));
+    } else {
+      rewriter.eraseOp(xferOp);
+    }
+  }
+
+  /// Return the initial loop state for the generated scf.for loop.
+  static Value initialLoopState(TransferWriteOp xferOp) {
+    return isTensorOp(xferOp) ? xferOp.source() : Value();
   }
 };
 
 template <typename OpTy>
-LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) {
+LogicalResult checkPrepareXferOp(OpTy xferOp,
+                                 VectorTransferToSCFOptions options) {
   if (xferOp->hasAttr(kPassLabel))
     return failure();
-  if (xferOp.getVectorType().getRank() <= targetRank)
+  if (xferOp.getVectorType().getRank() <= options.targetRank)
+    return failure();
+  if (isTensorOp(xferOp) && !options.lowerTensors)
     return failure();
   return success();
 }
@@ -513,7 +552,7 @@
 
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (checkPrepareXferOp(xferOp, options.targetRank).failed())
+    if (checkPrepareXferOp(xferOp, options).failed())
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
@@ -561,7 +600,7 @@
 
   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (checkPrepareXferOp(xferOp, options.targetRank).failed())
+    if (checkPrepareXferOp(xferOp, options).failed())
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
@@ -599,6 +638,18 @@
 /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
 /// out-of-bounds, generate an if-check and handle both cases separately.
 /// 3. Clean up according to the corresponding Strategy<OpTy>.
+///
+/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
+/// source (as opposed to a memref source), then each iteration of the generated
+/// scf.for loop yields the new tensor value. E.g.:
+/// ```
+/// %result = scf.for i = 0 to 5 {
+///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
+///   %1 = vector.transfer_write %0, %source[...]
+///       : vector<4x3xf32>, tensor<5x4x3xf32>
+///   scf.yield %1 : tensor<5x4x3xf32>
+/// }
+/// ```
 template <typename OpTy>
 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
@@ -641,19 +692,26 @@
             castedDataType.getDimSize(castedDataType.getRank() - 1))
             .value;
     auto step = std_constant_index(1).value;
+    // TransferWriteOps that operate on tensors return the modified tensor and
+    // require a loop state.
+    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
 
     // Generate for loop.
-    rewriter.create<scf::ForOp>(
-        xferOp.getLoc(), lb, ub, step, ValueRange(),
-        [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
+    auto result = rewriter.create<scf::ForOp>(
+        xferOp.getLoc(), lb, ub, step,
+        loopState ? ValueRange(loopState) : ValueRange(),
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
           ScopedContext scope(b, loc);
-          generateInBoundsCheck(
+          Type stateType = loopState.empty() ? Type() : loopState[0].getType();
+
+          auto result = generateInBoundsCheck(
               xferOp, iv, b, unpackedDim(xferOp),
+              stateType ? TypeRange(stateType) : TypeRange(),
               /*inBoundsCase=*/
               [&](OpBuilder &b, Location /*loc*/) {
                 // Create new transfer op.
                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
-                    b, this->options, xferOp, castedDataBuffer, iv);
+                    b, this->options, xferOp, castedDataBuffer, iv, loopState);
 
                 // If old transfer op has a mask: Set mask on new transfer op.
                 // Special case: If the mask of the old transfer op is 1D and
@@ -676,16 +734,19 @@
                   rewriter.updateRootInPlace(
                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                 }
+
+                return loopState.empty() ? Value() : newXfer->getResult(0);
               },
               /*outOfBoundsCase=*/
               [&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
-                                                     castedDataBuffer, iv);
+                return Strategy<OpTy>::handleOutOfBoundsDim(
+                    b, xferOp, castedDataBuffer, iv, loopState);
               });
-          b.create<scf::YieldOp>(loc);
+
+          maybeYieldValue(!loopState.empty(), b, loc, result);
         });
 
-    Strategy<OpTy>::cleanup(rewriter, xferOp);
+    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
     return success();
   }
 };
@@ -1160,12 +1221,14 @@
   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
     this->fullUnroll = options.unroll;
     this->targetRank = options.targetRank;
+    this->lowerTensors = options.lowerTensors;
   }
 
   void runOnFunction() override {
     VectorTransferToSCFOptions options;
     options.setUnroll(fullUnroll);
     options.setTargetRank(targetRank);
+    options.setLowerTensors(lowerTensors);
 
     RewritePatternSet patterns(getFunction().getContext());
     populateVectorToSCFConversionPatterns(patterns, options);
diff --git a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-tensors=true' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_2d(
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
+// CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4x9xf32>> to memref<4xvector<9xf32>>
+// CHECK: scf.for {{.*}} {
+// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %cst {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: memref.store %[[READ]], %[[CASTED]][%{{.*}}] : memref<4xvector<9xf32>>
+// CHECK: }
+// CHECK: %[[LOADED:.*]] = memref.load %[[ALLOC]][] : memref<vector<4x9xf32>>
+// CHECK: return %[[LOADED]] : vector<4x9xf32>
+func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
+    -> (vector<4x9xf32>) {
+  %p = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
+      : tensor<?x?xf32>, vector<4x9xf32>
+  return %f : vector<4x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_2d(
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<2x3xf32>>
+// CHECK: memref.store {{.*}}, %[[ALLOC]][] : memref<vector<2x3xf32>>
+// CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<2x3xf32>> to memref<2xvector<3xf32>>
+// CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[STATE:.*]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK: %[[LOADED:.*]] = memref.load %[[CASTED]][%{{.*}}] : memref<2xvector<3xf32>>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[LOADED]], %[[STATE]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: scf.yield %[[WRITE]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: return %[[RESULT]] : tensor<?x?xf32>
+func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
+    %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
+  %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
+      : vector<2x3xf32>, tensor<?x?xf32>
+  return %t : tensor<?x?xf32>
+}
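
Usage sketch (illustrative, not part of the patch): the new `lowerTensors` flag chains with the existing fluent setters, so downstream code can opt in programmatically. Only the helper name `buildVectorToSCFPatterns` below is hypothetical; `VectorTransferToSCFOptions` and `populateVectorToSCFConversionPatterns()` are the entry points touched by this diff.

```cpp
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Populate `patterns` with VectorToSCF lowerings that also rewrite transfer
// ops whose source/result is a tensor (hypothetical helper for illustration).
static void buildVectorToSCFPatterns(RewritePatternSet &patterns) {
  VectorTransferToSCFOptions options;
  options.setTargetRank(1)     // unpack transfers down to rank 1 (the default)
      .setLowerTensors(true);  // also rewrite transfer ops on tensors
  populateVectorToSCFConversionPatterns(patterns, options);
}
```

From the command line, the equivalent is the pass option exercised by the RUN line above: `-convert-vector-to-scf='lower-tensors=true'`.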