diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -316,13 +316,14 @@ /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> static MemRefType unpackOneDim(MemRefType type) { auto vectorType = dyn_cast<VectorType>(type.getElementType()); + assert(!vectorType.getScalableDims().front() && + "Cannot unpack scalable dim into memref"); auto memrefShape = type.getShape(); SmallVector<int64_t, 8> newMemrefShape; newMemrefShape.append(memrefShape.begin(), memrefShape.end()); newMemrefShape.push_back(vectorType.getDimSize(0)); return MemRefType::get(newMemrefShape, - VectorType::get(vectorType.getShape().drop_front(), - vectorType.getElementType())); + VectorType::Builder(vectorType).dropDim(0)); } /// Given a transfer op, find the memref from which the mask is loaded. This @@ -1074,8 +1075,14 @@ auto vec = getResultVector(xferOp, rewriter); auto vecType = dyn_cast<VectorType>(vec.getType()); auto xferVecType = xferOp.getVectorType(); - auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(), - xferVecType.getElementType()); + + if (xferVecType.getScalableDims().front()) { + // Cannot unroll a scalable dimension at compile time. + return failure(); + } + + VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0); + int64_t dimSize = xferVecType.getShape()[0]; // Generate fully unrolled loop of transfer ops. 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -58,13 +58,15 @@ return rewriter.notifyMatchFailure( op, "0-D and 1-D vectors are handled separately"); + if (dstType.getScalableDims().front()) + return rewriter.notifyMatchFailure( + op, "Cannot unroll leading scalable dim in dstType"); + auto loc = op.getLoc(); - auto eltType = dstType.getElementType(); int64_t dim = dstType.getDimSize(0); Value idx = op.getOperand(0); - VectorType lowType = - VectorType::get(dstType.getShape().drop_front(), eltType); + VectorType lowType = VectorType::Builder(dstType).dropDim(0); Value trueVal = rewriter.create<vector::CreateMaskOp>( loc, lowType, op.getOperands().drop_front()); Value falseVal = rewriter.create<arith::ConstantOp>( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -434,7 +434,7 @@ vectorShape.end()); for (unsigned i : broadcastedDims) unbroadcastedVectorShape[i] = 1; - VectorType unbroadcastedVectorType = VectorType::get( + VectorType unbroadcastedVectorType = read.getVectorType().cloneWith( unbroadcastedVectorShape, read.getVectorType().getElementType()); // `vector.load` supports vector types as memref's elements only when the diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1743,6 +1743,28 @@ // ----- +// CHECK-LABEL: func @transfer_read_1d_scalable_mask +// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32> +// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, 
%[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32> +// CHECK: return %[[r]] : vector<[4]xf32> +func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32> + return %vec : vector<[4]xf32> +} + +// ----- +// CHECK-LABEL: func @transfer_write_1d_scalable_mask +// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr +func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32> + return +} + +// ----- + func.func @genbool_0d_f() -> vector<i1> { %0 = vector.constant_mask [0] : vector<i1> return %0 : vector<i1> @@ -1847,6 +1869,28 @@ // ----- +// CHECK-LABEL: func.func @create_mask_2d_trailing_scalable( +// CHECK-SAME: %[[arg:.*]]: index) -> vector<1x[4]xi1> { +// CHECK: %[[zero_mask2d:.*]] = arith.constant dense<false> : vector<1x[4]xi1> +// CHECK: %[[zero_llvm_mask2d:.*]] = builtin.unrealized_conversion_cast %[[zero_mask2d]] : vector<1x[4]xi1> to !llvm.array<1 x vector<[4]xi1>> +// CHECK: %[[indices:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi32> +// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32 +// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<[4]xi32> +// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[arg_vec:.*]] = llvm.insertelement %[[arg_i32]], %[[undef]]{{\[}}%[[c0]] : i32] : vector<[4]xi32> +// CHECK: %[[splat_arg:.*]] = llvm.shufflevector %[[arg_vec]], %[[undef]] [0, 0, 0, 0] : vector<[4]xi32> +// CHECK: %[[llvm_mask1d:.*]] = arith.cmpi slt, %[[indices]], %[[splat_arg]] : 
vector<[4]xi32> +// CHECK: %[[llvm_mask2d:.*]] = llvm.insertvalue %[[llvm_mask1d]], %[[zero_llvm_mask2d]][0] : !llvm.array<1 x vector<[4]xi1>> +// CHECK: %[[mask2d:.*]] = builtin.unrealized_conversion_cast %[[llvm_mask2d]] : !llvm.array<1 x vector<[4]xi1>> to vector<1x[4]xi1> +// CHECK: return %[[mask2d]] : vector<1x[4]xi1> +func.func @create_mask_2d_trailing_scalable(%a: index) -> vector<1x[4]xi1> { + %c1 = arith.constant 1 : index + %mask = vector.create_mask %c1, %a : vector<1x[4]xi1> + return %mask : vector<1x[4]xi1> +} + +// ----- + func.func @transpose_0d(%arg0: vector<f32>) -> vector<f32> { %0 = vector.transpose %arg0, [] : vector<f32> to vector<f32> return %0 : vector<f32> diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -635,3 +635,69 @@ // CHECK: vector.print // CHECK: return // CHECK: } + +// ----- + +func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[4]xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %dim = memref.dim %arg0, %c1 : memref<3x?xf32> + %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1> + %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32> + return %read : vector<3x[4]xf32> +} +// CHECK-LABEL: func.func @transfer_read_array_of_scalable( +// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> { +// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>> +// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>> +// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32> +// 
CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1> +// CHECK: memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>> +// CHECK: %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>> +// CHECK: %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>> +// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] { +// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>> +// CHECK: %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32> +// CHECK: memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>> +// CHECK: } +// CHECK: %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>> +// CHECK: return %[[RESULT]] : vector<3x[4]xf32> +// CHECK: } + +// ----- + +func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memref<3x?xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %dim = memref.dim %arg0, %c1 : memref<3x?xf32> + %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1> + vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32> + return +} +// CHECK-LABEL: func.func @transfer_write_array_of_scalable( +// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>, +// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>> +// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>> +// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32> +// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : 
vector<3x[4]xi1> +// CHECK: memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>> +// CHECK: memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>> +// CHECK: %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>> +// CHECK: %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>> +// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] { +// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>> +// CHECK: %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>> +// CHECK: vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32> +// CHECK: } +// CHECK: return +// CHECK: }