diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1770,7 +1770,30 @@
   return success();
 }
 
+/// Returns true if `idxArr` holds exactly `width` integer attributes whose
+/// values form the consecutive sequence `begin`, `begin + 1`, ... in order.
+/// The parameters are unsigned to match the unsigned results of `size()` and
+/// `getZExtValue()`, avoiding mixed signed/unsigned comparisons.
+static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
+  uint64_t expected = begin;
+  return idxArr.size() == width &&
+         llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
+                      [&expected](auto attr) {
+                        return attr.getZExtValue() == expected++;
+                      });
+}
+
 OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
+  if (!getV1VectorType().isScalable() &&
+      isStepIndexArray(getMask(), 0, getV1VectorType().getDimSize(0)))
+    return getV1();
+  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
+  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
+      isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
+                       getV2VectorType().getDimSize(0)))
+    return getV2();
+
   Attribute lhs = operands.front(), rhs = operands.back();
   if (!lhs || !rhs)
     return {};
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -425,8 +425,7 @@
 // CHECK-LABEL: @shuffle_1D_direct(
 // CHECK-SAME: %[[A:.*]]: vector<2xf32>,
 // CHECK-SAME: %[[B:.*]]: vector<2xf32>)
-// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, 1] : vector<2xf32>, vector<2xf32>
-// CHECK: return %[[s]] : vector<2xf32>
+// CHECK: return %[[A]] : vector<2xf32>
 
 // -----
 
@@ -437,11 +436,7 @@
 // CHECK-LABEL: @shuffle_1D_index_direct(
 // CHECK-SAME: %[[A:.*]]: vector<2xindex>,
 // CHECK-SAME: %[[B:.*]]: vector<2xindex>)
-// CHECK-DAG: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
-// CHECK-DAG: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2xindex> to vector<2xi64>
-// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [0, 1] : vector<2xi64>, vector<2xi64>
-// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2xi64> to vector<2xindex>
-// CHECK: return %[[T3]] : vector<2xindex>
+// CHECK: return %[[A]] : vector<2xindex>
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1276,3 +1276,47 @@
   %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xi32>, vector<3xi32>
   return %shuffle : vector<4xi32>
 }
+
+// CHECK-LABEL: func @shuffle_fold1
+// CHECK: return %arg0 : vector<4xi32>
+func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// CHECK-LABEL: func @shuffle_fold2
+// CHECK: return %arg1 : vector<2xi32>
+func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [4, 5] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<2xi32>
+}
+
+// CHECK-LABEL: func @shuffle_fold3
+// CHECK: return %arg0 : vector<4x5x6xi32>
+func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<4x5x6xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<4x5x6xi32>, vector<2x5x6xi32>
+  return %shuffle : vector<4x5x6xi32>
+}
+
+// CHECK-LABEL: func @shuffle_fold4
+// CHECK: return %arg1 : vector<2x5x6xi32>
+func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<2x5x6xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [4, 5] : vector<4x5x6xi32>, vector<2x5x6xi32>
+  return %shuffle : vector<2x5x6xi32>
+}
+
+// CHECK-LABEL: func @shuffle_nofold1
+// CHECK: %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32>
+// CHECK: return %[[V]]
+func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<5xi32>
+}
+
+// CHECK-LABEL: func @shuffle_nofold2
+// CHECK: %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3] : vector<[4]xi32>, vector<[2]xi32>
+// CHECK: return %[[V]]
+func @shuffle_nofold2(%v0 : vector<[4]xi32>, %v1 : vector<[2]xi32>) -> vector<4xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<[4]xi32>, vector<[2]xi32>
+  return %shuffle : vector<4xi32>
+}