diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -523,7 +523,9 @@
            "Target vector rank to which transfer ops should be lowered">,
     Option<"lowerPermutationMaps", "lower-permutation-maps", "bool",
            /*default=*/"false", "Replace permutation maps with vector "
-           "transposes/broadcasts before lowering transfer ops">
+           "transposes/broadcasts before lowering transfer ops">,
+    Option<"lowerTensors", "lower-tensors", "bool", /*default=*/"false",
+           "Lower transfer ops that operate on tensors">
   ];
 }
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -48,12 +48,18 @@
 /// is reused and only a second vector.type_cast is added.
 
 struct VectorTransferToSCFOptions {
-  bool unroll = false;
   unsigned targetRank = 1;
   bool lowerPermutationMaps = false;
+  bool lowerTensors = false;
+  bool unroll = false;
 
-  VectorTransferToSCFOptions &setUnroll(bool u) {
-    unroll = u;
+  VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
+    lowerPermutationMaps = l;
+    return *this;
+  }
+
+  VectorTransferToSCFOptions &setLowerTensors(bool l) {
+    lowerTensors = l;
     return *this;
   }
 
@@ -62,8 +68,8 @@
     return *this;
   }
 
-  VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
-    lowerPermutationMaps = l;
+  VectorTransferToSCFOptions &setUnroll(bool u) {
+    unroll = u;
     return *this;
   }
 };
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -99,6 +99,7 @@
 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                             Value value) {
   if (hasRetVal) {
+    assert(value && "Expected non-empty value");
     b.create<scf::YieldOp>(loc, value);
   } else {
     b.create<scf::YieldOp>(loc);
@@ -242,6 +243,19 @@
   newXferOp->setAttr(kPassLabel, b.getUnitAttr());
 }
 
+/// Return true if this transfer op operates on a source tensor.
+template <typename OpTy>
+static bool isTensorOp(OpTy xferOp) {
+  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
+    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
+      // TransferWriteOps on tensors have a result.
+      assert(xferOp->getNumResults() > 0);
+    }
+    return true;
+  }
+  return false;
+}
+
 namespace lowering_n_d {
 
 /// Helper data structure for data and mask buffers.
@@ -365,8 +379,8 @@
   /// Note: The `mask` operand is set in TransferOpConversion.
   static TransferReadOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
-                                  TransferReadOp xferOp, Value buffer,
-                                  Value iv) {
+                                  TransferReadOp xferOp, Value buffer, Value iv,
+                                  ValueRange /*loopState*/) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -391,8 +405,9 @@
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
   /// padding value to the temporary buffer.
-  static void handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
-                                   Value buffer, Value iv) {
+  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
+                                    Value buffer, Value iv,
+                                    ValueRange /*loopState*/) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -402,13 +417,19 @@
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
+
+    return Value();
   }
 
   /// Cleanup after rewriting the op.
-  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
+  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
+                      scf::ForOp /*forOp*/) {
     rewriter.eraseOp(getStoreOp(xferOp));
     rewriter.eraseOp(xferOp);
   }
+
+  /// Return the initial loop state for the generated scf.for loop.
+  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
 };
 
 /// Codegen strategy for vector TransferWriteOp.
@@ -447,7 +468,7 @@
   static TransferWriteOp rewriteOp(OpBuilder &b,
                                    VectorTransferToSCFOptions options,
                                    TransferWriteOp xferOp, Value buffer,
-                                   Value iv) {
+                                   Value iv, ValueRange loopState) {
     SmallVector<Value, 8> loadIndices;
     getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
@@ -458,8 +479,10 @@
     Location loc = xferOp.getLoc();
     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+    auto source = loopState.empty() ? xferOp.source() : loopState[0];
+    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
     auto newXferOp = b.create<vector::TransferWriteOp>(
-        loc, Type(), vec, xferOp.source(), xferIndices,
+        loc, type, vec, source, xferIndices,
         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
         inBoundsAttr);
@@ -469,12 +492,26 @@
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
-  static void handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
-                                   Value buffer, Value iv) {}
+  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
+                                    Value buffer, Value iv,
+                                    ValueRange loopState) {
+    return isTensorOp(xferOp) ? loopState[0] : Value();
+  }
 
   /// Cleanup after rewriting the op.
-  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
-    rewriter.eraseOp(xferOp);
+  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
+                      scf::ForOp forOp) {
+    if (isTensorOp(xferOp)) {
+      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
+      rewriter.replaceOp(xferOp, forOp->getResult(0));
+    } else {
+      rewriter.eraseOp(xferOp);
+    }
+  }
+
+  /// Return the initial loop state for the generated scf.for loop.
+  static Value initialLoopState(TransferWriteOp xferOp) {
+    return isTensorOp(xferOp) ? xferOp.source() : Value();
   }
 };
 
@@ -485,7 +522,7 @@
     return failure();
   if (xferOp.getVectorType().getRank() <= options.targetRank)
     return failure();
-  if (xferOp.getShapedType().template isa<RankedTensorType>())
+  if (isTensorOp(xferOp) && !options.lowerTensors)
     return failure();
   // Transfer ops that modify the element type are not supported atm.
   if (xferOp.getVectorType().getElementType() !=
@@ -610,6 +647,18 @@
 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
 ///    out-of-bounds, generate an if-check and handle both cases separately.
 /// 3. Clean up according to the corresponding Strategy<OpTy>.
+///
+/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
+/// source (as opposed to a memref source), then each iteration of the
+/// generated scf.for loop yields the new tensor value. E.g.:
+/// ```
+/// %result = scf.for i = 0 to 5 {
+///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
+///   %1 = vector.transfer_write %0, %source[...]
+///       : vector<4x3xf32>, tensor<5x4x3xf32>
+///   scf.yield %1 : tensor<5x4x3xf32>
+/// }
+/// ```
 template <typename OpTy>
 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
@@ -652,18 +701,24 @@
     auto ub = locB.create<ConstantIndexOp>(
         castedDataType.getDimSize(castedDataType.getRank() - 1));
     auto step = locB.create<ConstantIndexOp>(1);
+    // TransferWriteOps that operate on tensors return the modified tensor and
+    // require a loop state.
+    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
 
     // Generate for loop.
-    locB.create<scf::ForOp>(
-        lb, ub, step, ValueRange(),
-        [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
-          generateInBoundsCheck(
+    auto result = locB.create<scf::ForOp>(
+        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
+          Type stateType = loopState.empty() ? Type() : loopState[0].getType();
+
+          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
+              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
-                    b, this->options, xferOp, castedDataBuffer, iv);
+                    b, this->options, xferOp, castedDataBuffer, iv, loopState);
 
                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
@@ -687,16 +742,19 @@
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }
+
+                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
-                                                     castedDataBuffer, iv);
+                return Strategy<OpTy>::handleOutOfBoundsDim(
+                    b, xferOp, castedDataBuffer, iv, loopState);
              });
-          b.create<scf::YieldOp>(loc);
+
+          maybeYieldValue(b, loc, !loopState.empty(), result);
        });
 
-    Strategy<OpTy>::cleanup(rewriter, xferOp);
+    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
     return success();
   }
 };
@@ -1184,6 +1242,7 @@
     this->fullUnroll = options.unroll;
     this->targetRank = options.targetRank;
     this->lowerPermutationMaps = options.lowerPermutationMaps;
+    this->lowerTensors = options.lowerTensors;
   }
 
   void runOnFunction() override {
@@ -1191,6 +1250,7 @@
     options.unroll = fullUnroll;
     options.targetRank = targetRank;
     options.lowerPermutationMaps = lowerPermutationMaps;
+    options.lowerTensors = lowerTensors;
 
     // Lower permutation maps first.
     if (lowerPermutationMaps) {
diff --git a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
new file
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-tensors=true' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_2d(
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
+// CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4x9xf32>> to memref<4xvector<9xf32>>
+// CHECK: scf.for {{.*}} {
+// CHECK:   %[[READ:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %cst {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK:   memref.store %[[READ]], %[[CASTED]][%{{.*}}] : memref<4xvector<9xf32>>
+// CHECK: }
+// CHECK: %[[LOADED:.*]] = memref.load %[[ALLOC]][] : memref<vector<4x9xf32>>
+// CHECK: return %[[LOADED]] : vector<4x9xf32>
+func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
+    -> (vector<4x9xf32>) {
+  %p = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
+      : tensor<?x?xf32>, vector<4x9xf32>
+  return %f : vector<4x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_2d(
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<2x3xf32>>
+// CHECK: memref.store {{.*}}, %[[ALLOC]][] : memref<vector<2x3xf32>>
+// CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<2x3xf32>> to memref<2xvector<3xf32>>
+// CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[STATE:.*]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK:   %[[LOADED:.*]] = memref.load %[[CASTED]][%{{.*}}] : memref<2xvector<3xf32>>
+// CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[LOADED]], %[[STATE]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK:   scf.yield %[[WRITE]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: return %[[RESULT]] : tensor<?x?xf32>
+func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
+    %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
+  %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
+      : vector<2x3xf32>, tensor<?x?xf32>
+  return %t : tensor<?x?xf32>
+}
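
Note: the RUN line above exercises the new behavior through the `lower-tensors=true` pass option. For completeness, here is a minimal sketch of the programmatic equivalent. It assumes the existing `createConvertVectorToSCFPass(options)` factory declared in `mlir/Conversion/VectorToSCF/VectorToSCF.h` and nesting under `FuncOp` (the pass runs on functions via `runOnFunction`); the helper name `addVectorToSCFWithTensors` is made up for illustration.

```c++
// Sketch only: build a pipeline that lowers vector transfer ops on tensors,
// mirroring -convert-vector-to-scf='lower-tensors=true' from the test above.
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

static void addVectorToSCFWithTensors(OpPassManager &pm) {
  VectorTransferToSCFOptions options;
  // Chaining works because each setter returns *this. setLowerTensors is the
  // new setter added by this patch.
  options.setLowerPermutationMaps(true).setLowerTensors(true);
  // Assumed factory from VectorToSCF.h; the pass is function-scoped, so it is
  // nested under FuncOp.
  pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass(options));
}
```

Without `setLowerTensors(true)` (or `lower-tensors=true`), transfer ops whose source is a tensor are still rejected in `checkPrepareXferOp`, so the pattern does not fire on tensor-based transfers.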