diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,7 +866,7 @@
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
       return failure();
-    if (xferOp.getShapedType().template isa<TensorType>())
+    if (isTensorOp(xferOp) && !options.lowerTensors)
       return failure();
     // Transfer ops that modify the element type are not supported atm.
     if (xferOp.getVectorType().getElementType() !=
@@ -988,7 +988,7 @@
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
       return failure();
-    if (xferOp.getShapedType().template isa<TensorType>())
+    if (isTensorOp(xferOp) && !options.lowerTensors)
       return failure();
     // Transfer ops that modify the element type are not supported atm.
     if (xferOp.getVectorType().getElementType() !=
@@ -998,15 +998,19 @@
     auto vec = getDataVector(xferOp);
     auto xferVecType = xferOp.getVectorType();
     int64_t dimSize = xferVecType.getShape()[0];
+    auto source = xferOp.source(); // memref or tensor to be written to.
+    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
 
     // Generate fully unrolled loop of transfer ops.
     Location loc = xferOp.getLoc();
     for (int64_t i = 0; i < dimSize; ++i) {
       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
 
-      generateInBoundsCheck(
+      auto updatedSource = generateInBoundsCheck(
           rewriter, xferOp, iv, unpackedDim(xferOp),
-          /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
+          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
+          /*inBoundsCase=*/
+          [&](OpBuilder &b, Location loc) {
             // Indices for the new transfer op.
             SmallVector<Value, 8> xferIndices;
             getXferIndices(b, xferOp, iv, xferIndices);
@@ -1019,17 +1023,29 @@
             auto extracted =
                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
-
             auto newXferOp = b.create<vector::TransferWriteOp>(
-                loc, Type(), extracted, xferOp.source(), xferIndices,
+                loc, sourceType, extracted, source, xferIndices,
                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                 inBoundsAttr);
 
             maybeAssignMask(b, xferOp, newXferOp, i);
+
+            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
+          },
+          /*outOfBoundsCase=*/
+          [&](OpBuilder &b, Location loc) {
+            return isTensorOp(xferOp) ? source : Value();
           });
+
+      if (isTensorOp(xferOp))
+        source = updatedSource;
     }
 
-    rewriter.eraseOp(xferOp);
+    if (isTensorOp(xferOp))
+      rewriter.replaceOp(xferOp, source);
+    else
+      rewriter.eraseOp(xferOp);
+
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir
new file
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-tensors=true' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_2d(
+// CHECK: %[[V_INIT:.*]] = constant dense<-4.200000e+01> : vector<4x9xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[V0]], %[[V_INIT]] [0] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[V1]], %[[I0]] [1] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I2:.*]] = vector.insert %[[V2]], %[[I1]] [2] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I3:.*]] = vector.insert %[[V3]], %[[I2]] [3] : vector<9xf32> into vector<4x9xf32>
+// CHECK: return %[[I3]] : vector<4x9xf32>
+func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
+    -> (vector<4x9xf32>) {
+  %p = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
+      : tensor<?x?xf32>, vector<4x9xf32>
+  return %f : vector<4x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_2d(
+// CHECK: %[[V0:.*]] = vector.extract %{{.*}}[0] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.transfer_write %[[V0]], %{{.*}}[{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: %[[V1:.*]] = vector.extract %{{.*}}[1] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.transfer_write %[[V1]], %[[T0]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: return %[[T1]] : tensor<?x?xf32>
+func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
+                        %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
+  %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
+      : vector<2x3xf32>, tensor<?x?xf32>
+  return %t : tensor<?x?xf32>
+}
+