diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4949,11 +4949,15 @@ Value vector, ArrayRef transp) { VectorType vt = llvm::cast(vector.getType()); SmallVector transposedShape(vt.getRank()); - for (unsigned i = 0; i < transp.size(); ++i) + SmallVector transposedScalableDims(vt.getRank()); + for (unsigned i = 0; i < transp.size(); ++i) { transposedShape[i] = vt.getShape()[transp[i]]; + transposedScalableDims[i] = vt.getScalableDims()[transp[i]]; + } result.addOperands(vector); - result.addTypes(VectorType::get(transposedShape, vt.getElementType())); + result.addTypes(VectorType::get(transposedShape, vt.getElementType(), + transposedScalableDims)); result.addAttribute(TransposeOp::getTranspAttrName(result.name), builder.getI64ArrayAttr(transp)); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -115,8 +115,11 @@ // Apply the reverse transpose to deduce the type of the transfer_read. ArrayRef originalShape = op.getVectorType().getShape(); SmallVector newVectorShape(originalShape.size()); + ArrayRef originalScalableDims = op.getVectorType().getScalableDims(); + SmallVector newScalableDims(originalShape.size()); for (const auto &pos : llvm::enumerate(permutation)) { newVectorShape[pos.value()] = originalShape[pos.index()]; + newScalableDims[pos.value()] = originalScalableDims[pos.index()]; } // Transpose in_bounds attribute. @@ -126,8 +129,8 @@ : ArrayAttr(); // Generate new transfer_read operation. - VectorType newReadType = - VectorType::get(newVectorShape, op.getVectorType().getElementType()); + VectorType newReadType = VectorType::get( + newVectorShape, op.getVectorType().getElementType(), newScalableDims); Value newRead = rewriter.create( op.getLoc(), newReadType, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), @@ -345,14 +348,16 @@ return success(); } - SmallVector newShape = llvm::to_vector<4>( + SmallVector newShape( originalVecType.getShape().take_back(reducedShapeRank)); + SmallVector newScalableDims( + originalVecType.getScalableDims().take_back(reducedShapeRank)); // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. if (newShape.empty()) return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d"); - VectorType newReadType = - VectorType::get(newShape, originalVecType.getElementType()); + VectorType newReadType = VectorType::get( + newShape, originalVecType.getElementType(), newScalableDims); ArrayAttr newInBoundsAttr = op.getInBounds() ? rewriter.getArrayAttr( diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -1,12 +1,12 @@ // RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s -// CHECK-LABEL: func @lower_permutation_with_mask( +// CHECK-LABEL: func @lower_permutation_with_mask_fixed_width( // CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32> // CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1> // CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1> // CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1> // CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref -func.func @lower_permutation_with_mask(%A : memref, %base1 : index, +func.func @lower_permutation_with_mask_fixed_width(%A : memref, %base1 : index, %base2 : index) { %fn1 = arith.constant -2.0 : f32 %vf0 = vector.splat %fn1 : vector<7xf32> @@ -17,6 +17,30 @@ return } +// CHECK-LABEL: func.func @permutation_with_mask_scalable( +// CHECK-SAME: %[[ARG_0:.*]]: memref, +// CHECK-SAME: %[[IDX_1:.*]]: index, +// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x[4]xi1> +// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref, vector<2x[4]xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32> +// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32> +// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32> +// CHECK: } +func.func @permutation_with_mask_scalable(%2: memref, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) { + + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + + %mask = vector.create_mask %dim_2, %dim_1 : vector<2x[4]xi1> + %1 = vector.transfer_read %2[%c0, %c0], %cst_0, %mask + {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>} + : memref, vector<8x[4]x2xf32> + return %1 : vector<8x[4]x2xf32> +} + transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): %f = transform.structured.match ops{["func.func"]} in %module_op