[mlir][vector] Support tensors in the fully unrolled VectorToSCF transfer lowerings

Both unrolled transfer patterns now bail out on tensor operands unless
`options.lowerTensors` is set.

For a tensor transfer_write, each unrolled 1-D transfer_write produces a new
tensor. generateInBoundsCheck yields that tensor in the in-bounds case and
the unchanged tensor in the out-of-bounds case. The next iteration writes
into the yielded tensor, and the original op is replaced by the final tensor
instead of being erased.

A new test, unrolled-tensor-transfer-ops.mlir, covers 2-D transfer_read and
transfer_write on tensor<?x?xf32> with full-unroll=true lower-tensors=true.

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -855,6 +855,8 @@
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
       return failure();
+    if (isTensorOp(xferOp) && !options.lowerTensors)
+      return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
     auto insertOp = getInsertOp(xferOp);
@@ -978,19 +980,25 @@
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
       return failure();
+    if (isTensorOp(xferOp) && !options.lowerTensors)
+      return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
     auto vec = getDataVector(xferOp);
     auto xferVecType = xferOp.getVectorType();
     int64_t dimSize = xferVecType.getShape()[0];
+    auto source = xferOp.source(); // memref or tensor to be written to.
+    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
 
     // Generate fully unrolled loop of transfer ops.
     for (int64_t i = 0; i < dimSize; ++i) {
       Value iv = std_constant_index(i);
 
-      generateInBoundsCheck(
+      auto updatedSource = generateInBoundsCheck(
           xferOp, iv, rewriter, unpackedDim(xferOp),
-          /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
+          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
+          /*inBoundsCase=*/
+          [&](OpBuilder &b, Location loc) {
             ScopedContext scope(b, loc);
 
             // Indices for the new transfer op.
@@ -1009,14 +1017,27 @@
             auto newXferOp = vector_transfer_write(
-                Type(), extracted, xferOp.source(), xferIndices,
+                sourceType, extracted, source, xferIndices,
                 AffineMapAttr::get(unpackedPermutationMap(xferOp, b)),
                 Value(), inBoundsAttr)
                 .op;
 
             maybeAssignMask(b, xferOp, newXferOp, i);
+
+            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
+          },
+          /*outOfBoundsCase=*/
+          [&](OpBuilder &b, Location loc) {
+            return isTensorOp(xferOp) ? source : Value();
           });
+
+      if (isTensorOp(xferOp))
+        source = updatedSource;
     }
 
-    rewriter.eraseOp(xferOp);
+    if (isTensorOp(xferOp))
+      rewriter.replaceOp(xferOp, source);
+    else
+      rewriter.eraseOp(xferOp);
+
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-tensors=true' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_2d(
+// CHECK: %[[V_INIT:.*]] = constant dense<-4.200000e+01> : vector<4x9xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[V0]], %[[V_INIT]] [0] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[V1]], %[[I0]] [1] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I2:.*]] = vector.insert %[[V2]], %[[I1]] [2] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I3:.*]] = vector.insert %[[V3]], %[[I2]] [3] : vector<9xf32> into vector<4x9xf32>
+// CHECK: return %[[I3]] : vector<4x9xf32>
+func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
+    -> (vector<4x9xf32>){
+  %p = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
+      : tensor<?x?xf32>, vector<4x9xf32>
+  return %f : vector<4x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_2d(
+// CHECK: %[[V0:.*]] = vector.extract %{{.*}}[0] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.transfer_write %[[V0]], %{{.*}}[{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: %[[V1:.*]] = vector.extract %{{.*}}[1] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.transfer_write %[[V1]], %[[T0]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: return %[[T1]] : tensor<?x?xf32>
+func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
+                        %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
+  %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
+      : vector<2x3xf32>, tensor<?x?xf32>
+  return %t : tensor<?x?xf32>
+}
+