diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -98,6 +98,7 @@
 static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
                             Value value) {
   if (hasRetVal) {
+    assert(value && "Expected non-empty value");
     builder.create<scf::YieldOp>(loc, value);
   } else {
     builder.create<scf::YieldOp>(loc);
   }
@@ -238,6 +239,16 @@
   newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
 }
 
+/// Return true if this TransferWriteOp operates on a source tensor.
+static bool isTensorOp(TransferWriteOp xferOp) {
+  if (xferOp->getNumResults() > 0) {
+    assert(xferOp.getShapedType().isa<TensorType>() &&
+           "Expected that TransferWriteOp with result has tensor source");
+    return true;
+  }
+  return false;
+}
+
 namespace lowering_n_d {
 
 /// Helper data structure for data and mask buffers.
@@ -360,8 +371,8 @@
   /// Note: The `mask` operand is set in TransferOpConversion.
   static TransferReadOp rewriteOp(OpBuilder &builder,
                                   VectorTransferToSCFOptions options,
-                                  TransferReadOp xferOp, Value buffer,
-                                  Value iv) {
+                                  TransferReadOp xferOp, Value buffer, Value iv,
+                                  ValueRange /*loopState*/) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -389,9 +400,9 @@
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
   /// padding value to the temporary buffer.
-  static void handleOutOfBoundsDim(OpBuilder & /*builder*/,
-                                   TransferReadOp xferOp, Value buffer,
-                                   Value iv) {
+  static Value handleOutOfBoundsDim(OpBuilder & /*builder*/,
+                                    TransferReadOp xferOp, Value buffer,
+                                    Value iv, ValueRange /*loopState*/) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -400,13 +411,19 @@
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
     auto vec = std_splat(vecType, xferOp.padding());
     memref_store(vec, buffer, storeIndices);
+
+    return Value();
   }
 
   /// Cleanup after rewriting the op.
-  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
+  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
+                      scf::ForOp /*forOp*/) {
     rewriter.eraseOp(getStoreOp(xferOp));
     rewriter.eraseOp(xferOp);
   }
+
+  /// Return the initial loop state for the generated scf.for loop.
+  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
 };
 
 /// Codegen strategy for vector TransferWriteOp.
@@ -445,7 +462,7 @@
   static TransferWriteOp rewriteOp(OpBuilder &builder,
                                    VectorTransferToSCFOptions options,
                                    TransferWriteOp xferOp, Value buffer,
-                                   Value iv) {
+                                   Value iv, ValueRange loopState) {
     SmallVector<Value, 8> loadIndices;
     getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
@@ -455,8 +472,10 @@
     auto vec = memref_load(buffer, loadIndices);
     auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
+    auto source = loopState.empty() ? xferOp.source() : loopState[0];
+    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
     auto newXfer = vector_transfer_write(
-        Type(), vec, xferOp.source(), xferIndices,
+        type, vec, source, xferIndices,
         AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
         inBoundsAttr);
@@ -466,12 +485,26 @@
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
-  static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
-                                   Value buffer, Value iv) {}
+  static Value handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
+                                    Value buffer, Value iv,
+                                    ValueRange loopState) {
+    return isTensorOp(xferOp) ? loopState[0] : Value();
+  }
 
   /// Cleanup after rewriting the op.
-  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
-    rewriter.eraseOp(xferOp);
+  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
+                      scf::ForOp forOp) {
+    if (isTensorOp(xferOp)) {
+      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
+      rewriter.replaceOp(xferOp, forOp->getResult(0));
+    } else {
+      rewriter.eraseOp(xferOp);
+    }
+  }
+
+  /// Return the initial loop state for the generated scf.for loop.
+  static Value initialLoopState(TransferWriteOp xferOp) {
+    return isTensorOp(xferOp) ? xferOp.source() : Value();
+  }
 };
 
@@ -599,6 +632,18 @@
 /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
 /// out-of-bounds, generate an if-check and handle both cases separately.
 /// 3. Clean up according to the corresponding Strategy<OpTy>.
+///
+/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
+/// source (as opposed to a memref source), then each iteration of the
+/// generated scf.for loop yields the new tensor value. E.g.:
+/// ```
+/// %result = scf.for i = 0 to 5 {
+///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
+///   %1 = vector.transfer_write %0, %source[...]
+///       : vector<4x3xf32>, tensor<5x4x3xf32>
+///   scf.yield %1 : tensor<5x4x3xf32>
+/// }
+/// ```
 template <typename OpTy>
 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
@@ -641,19 +686,26 @@
                   castedDataType.getDimSize(castedDataType.getRank() - 1))
                   .value;
     auto step = std_constant_index(1).value;
+    // TransferWriteOps that operate on tensors return the modified tensor and
+    // require a loop state.
+    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
 
     // Generate for loop.
-    rewriter.create<scf::ForOp>(
-        xferOp.getLoc(), lb, ub, step, ValueRange(),
-        [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
+    auto result = rewriter.create<scf::ForOp>(
+        xferOp.getLoc(), lb, ub, step,
+        loopState ? ValueRange(loopState) : ValueRange(),
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
           ScopedContext scope(b, loc);
-          generateInBoundsCheck(
+          Type stateType = loopState.empty() ? Type() : loopState[0].getType();
+
+          auto result = generateInBoundsCheck(
               xferOp, iv, b, unpackedDim(xferOp),
+              stateType ? TypeRange(stateType) : TypeRange(),
               /*inBoundsCase=*/
               [&](OpBuilder &b, Location /*loc*/) {
                 // Create new transfer op.
                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
-                    b, this->options, xferOp, castedDataBuffer, iv);
+                    b, this->options, xferOp, castedDataBuffer, iv, loopState);
 
                 // If old transfer op has a mask: Set mask on new transfer op.
                 // Special case: If the mask of the old transfer op is 1D and
@@ -676,16 +728,19 @@
                   rewriter.updateRootInPlace(
                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                 }
+
+                return loopState.empty() ? Value() : newXfer->getResult(0);
               },
               /*outOfBoundsCase=*/
               [&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
-                                                     castedDataBuffer, iv);
+                return Strategy<OpTy>::handleOutOfBoundsDim(
+                    b, xferOp, castedDataBuffer, iv, loopState);
               });
-          b.create<scf::YieldOp>(loc);
+
+          maybeYieldValue(!loopState.empty(), b, loc, result);
         });
 
-    Strategy<OpTy>::cleanup(rewriter, xferOp);
+    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_2d(
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
+// CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4x9xf32>> to memref<4xvector<9xf32>>
+// CHECK: scf.for {{.*}} {
+// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %cst {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: memref.store %[[READ]], %[[CASTED]][%{{.*}}] : memref<4xvector<9xf32>>
+// CHECK: }
+// CHECK: %[[LOADED:.*]] = memref.load %[[ALLOC]][] : memref<vector<4x9xf32>>
+// CHECK: return %[[LOADED]] : vector<4x9xf32>
+func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
+    -> (vector<4x9xf32>) {
+  %p = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
+      : tensor<?x?xf32>, vector<4x9xf32>
+  return %f : vector<4x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_2d(
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<2x3xf32>>
+// CHECK: memref.store {{.*}}, %[[ALLOC]][] : memref<vector<2x3xf32>>
+// CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<2x3xf32>> to memref<2xvector<3xf32>>
+// CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[STATE:.*]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK: %[[LOADED:.*]] = memref.load %[[CASTED]][%{{.*}}] : memref<2xvector<3xf32>>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[LOADED]], %[[STATE]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: scf.yield %[[WRITE]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: return %[[RESULT]] : tensor<?x?xf32>
+func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
+                        %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
+  %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
+      : vector<2x3xf32>, tensor<?x?xf32>
+  return %t : tensor<?x?xf32>
+}
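
---

Note (illustrative sketch, not part of the patch and not verified pass output): when the to-be-unpacked dimension of a tensor-source transfer_write may be out of bounds, the pieces above compose as follows. `generateInBoundsCheck` is now given a result type, so the loop body wraps the write in an `scf.if` whose in-bounds branch yields the tensor produced by `Strategy::rewriteOp` and whose out-of-bounds branch forwards the unchanged iteration state (`handleOutOfBoundsDim` returning `loopState[0]`); `maybeYieldValue` then yields the selected value from the loop body. The SSA names (`%in_bounds`, `%casted`, `%idx0`, `%idx1`, etc.) are made up for illustration; the pass derives the actual condition from the transfer indices.

```mlir
%result = scf.for %i = %c0 to %c2 step %c1
    iter_args(%state = %A) -> (tensor<?x?xf32>) {
  // In-bounds check produced by generateInBoundsCheck (condition elided).
  %next = scf.if %in_bounds -> (tensor<?x?xf32>) {
    // In-bounds: unpack one vector from the buffer and write it.
    %v = memref.load %casted[%i] : memref<2xvector<3xf32>>
    %w = vector.transfer_write %v, %state[%idx0, %idx1] {in_bounds = [true]}
        : vector<3xf32>, tensor<?x?xf32>
    scf.yield %w : tensor<?x?xf32>
  } else {
    // Out-of-bounds: the tensor flows through unchanged.
    scf.yield %state : tensor<?x?xf32>
  }
  scf.yield %next : tensor<?x?xf32>
}
```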