diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -314,15 +314,18 @@
 /// the VectorType into the MemRefType.
 ///
 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
-static MemRefType unpackOneDim(MemRefType type) {
+static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
   auto vectorType = dyn_cast<VectorType>(type.getElementType());
+  // Vectors with leading scalable dims are not supported.
+  // It may be possible to support these in future by using dynamic memref dims.
+  if (vectorType.getScalableDims().front())
+    return failure();
   auto memrefShape = type.getShape();
   SmallVector<int64_t, 8> newMemrefShape;
   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
   newMemrefShape.push_back(vectorType.getDimSize(0));
   return MemRefType::get(newMemrefShape,
-                         VectorType::get(vectorType.getShape().drop_front(),
-                                         vectorType.getElementType()));
+                         VectorType::Builder(vectorType).dropDim(0));
 }
 
 /// Given a transfer op, find the memref from which the mask is loaded. This
@@ -542,6 +545,10 @@
     return failure();
   if (xferOp.getVectorType().getRank() <= options.targetRank)
     return failure();
+  // Currently the unpacking of the leading dimension into the memref is not
+  // supported for scalable dimensions.
+  if (xferOp.getVectorType().getScalableDims().front())
+    return failure();
   if (isTensorOp(xferOp) && !options.lowerTensors)
     return failure();
   // Transfer ops that modify the element type are not supported atm.
@@ -866,8 +873,11 @@
     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
     auto castedDataType = unpackOneDim(dataBufferType);
+    if (failed(castedDataType))
+      return failure();
+
     auto castedDataBuffer =
-        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
+        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
 
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
@@ -882,7 +892,9 @@
         //   be broadcasted.)
         castedMaskBuffer = maskBuffer;
       } else {
-        auto castedMaskType = unpackOneDim(maskBufferType);
+        // It's safe to assume the mask buffer can be unpacked if the data
+        // buffer was unpacked.
+        auto castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -891,7 +903,7 @@
     // Loop bounds and step.
     auto lb = locB.create<arith::ConstantIndexOp>(0);
     auto ub = locB.create<arith::ConstantIndexOp>(
-        castedDataType.getDimSize(castedDataType.getRank() - 1));
+        castedDataType->getDimSize(castedDataType->getRank() - 1));
     auto step = locB.create<arith::ConstantIndexOp>(1);
     // TransferWriteOps that operate on tensors return the modified tensor and
     // require a loop state.
@@ -1074,8 +1086,14 @@
     auto vec = getResultVector(xferOp, rewriter);
     auto vecType = dyn_cast<VectorType>(vec.getType());
     auto xferVecType = xferOp.getVectorType();
-    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
-                                          xferVecType.getElementType());
+
+    if (xferVecType.getScalableDims()[0]) {
+      // Cannot unroll a scalable dimension at compile time.
+      return failure();
+    }
+
+    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
+
     int64_t dimSize = xferVecType.getShape()[0];
 
     // Generate fully unrolled loop of transfer ops.
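Note on the type-level change above (editor's sketch, not part of the patch): the old code rebuilt the element vector type with VectorType::get(shape.drop_front(), elementType), which reconstructs the type from the static shape alone and silently discards the per-dimension scalability flags; VectorType::Builder(vectorType).dropDim(0) drops the leading dimension together with its flag. A minimal standalone illustration, assuming only the MLIR builtin vector type API already used in this change (helper names are illustrative):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Rebuilding from the trailing shape only: for vector<3x[4]xf32> this yields
// vector<4xf32>, i.e. the scalability of the inner dim is silently lost.
static VectorType dropLeadingDimLossy(VectorType vectorType) {
  return VectorType::get(vectorType.getShape().drop_front(),
                         vectorType.getElementType());
}

// VectorType::Builder carries the scalable-dims mask along with the shape,
// so vector<3x[4]xf32> becomes vector<[4]xf32>, as the lowering requires.
static VectorType dropLeadingDimPreserving(VectorType vectorType) {
  return VectorType::Builder(vectorType).dropDim(0);
}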
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -635,3 +635,106 @@
 // CHECK:           vector.print
 // CHECK:           return
 // CHECK:         }
+
+// -----
+
+func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
+  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32>
+  return %read : vector<3x[4]xf32>
+}
+// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
+// CHECK-SAME:    %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
+// CHECK:         %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[C3:.*]] = arith.constant 3 : index
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
+// CHECK:         %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
+// CHECK:         %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
+// CHECK:         %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
+// CHECK:         memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
+// CHECK:         %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+// CHECK:         %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+// CHECK:         scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:           %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
+// CHECK:           %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+// CHECK:           memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK:         }
+// CHECK:         %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
+// CHECK:         return %[[RESULT]] : vector<3x[4]xf32>
+// CHECK:       }
+
+// -----
+
+func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memref<3x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
+  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32>
+  return
+}
+// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
+// CHECK-SAME:    %[[VEC:.*]]: vector<3x[4]xf32>,
+// CHECK-SAME:    %[[MEMREF:.*]]: memref<3x?xf32>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[C3:.*]] = arith.constant 3 : index
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
+// CHECK:         %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
+// CHECK:         %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
+// CHECK:         %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
+// CHECK:         memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
+// CHECK:         memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
+// CHECK:         %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+// CHECK:         %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+// CHECK:         scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:           %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK:           %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
+// CHECK:           vector.transfer_write %[[VECTOR_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[MASK_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+// CHECK:         }
+// CHECK:         return
+// CHECK:       }
+
+// -----
+
+/// The following two tests currently cannot be lowered via unpacking the leading dim, since it is scalable.
+/// It may be possible to special case this via a dynamic dim in future.
+
+func.func @cannot_lower_transfer_write_with_leading_scalable(%vec: vector<[4]x4xf32>, %arg0: memref<?x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
+  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
+  return
+}
+// CHECK-LABEL: func.func @cannot_lower_transfer_write_with_leading_scalable(
+// CHECK-SAME:    %[[VEC:.*]]: vector<[4]x4xf32>,
+// CHECK-SAME:    %[[MEMREF:.*]]: memref<?x4xf32>)
+// CHECK:         vector.transfer_write %[[VEC]], %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
+
+// -----
+
+func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf32>) -> vector<[4]x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
+  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
+  return %read : vector<[4]x4xf32>
+}
+// CHECK-LABEL: func.func @cannot_lower_transfer_read_with_leading_scalable(
+// CHECK-SAME:    %[[MEMREF:.*]]: memref<?x4xf32>)
+// CHECK:         %{{.*}} = vector.transfer_read %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
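For reference, a standalone sketch (editor's addition, not part of the patch) of the two memref-of-vector element types exercised by the tests above, built with the MLIR C++ API. It assumes the VectorType::get overload that takes a per-dimension scalableDims mask (the counterpart of getScalableDims() used in the patch); the function name is illustrative only.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static void buildTestTypes(MLIRContext &ctx) {
  auto f32 = Float32Type::get(&ctx);

  // memref<vector<3x[4]xf32>>: fixed leading dim, scalable trailing dim.
  // unpackOneDim succeeds for this buffer and yields memref<3xvector<[4]xf32>>,
  // which the generated scf.for loop then indexes one slice at a time.
  auto packedOk = MemRefType::get(
      {}, VectorType::get({3, 4}, f32, /*scalableDims=*/{false, true}));

  // memref<vector<[4]x4xf32>>: the leading dim is scalable, so the unpacked
  // memref would need a dynamic size. unpackOneDim returns failure() and the
  // pattern leaves the transfer op untouched, as the cannot_lower_* tests check.
  auto packedBad = MemRefType::get(
      {}, VectorType::get({4, 4}, f32, /*scalableDims=*/{true, false}));

  (void)packedOk;
  (void)packedBad;
}

Returning FailureOr<MemRefType> rather than producing a fixed-size stand-in type lets both lowering patterns bail out cleanly on the unsupported leading-scalable case, which is exactly the behavior the two negative tests verify.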