diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1151,7 +1151,8 @@ auto n = std::min(adaptor.getPosition().size(), vectorType.getRank()); inferredReturnTypes.push_back(VectorType::get( - vectorType.getShape().drop_front(n), vectorType.getElementType())); + vectorType.getShape().drop_front(n), vectorType.getElementType(), + vectorType.getScalableDims().drop_front(n))); } return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -416,6 +416,18 @@ } }; +/// Matches elementwise operations on vectors with at least one leading +/// dimension equal to 1, e.g. vector<1x[4]x1xf32> (not vector<2x[4]x1xf32>), +/// casts away the leading unit dimensions (_plural_) and then broadcasts +/// the result back to the original vector type. +/// +/// Example before: +/// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> +/// Example after (%0/%1 are %arg0/%arg1 with the leading unit dim cast away): +/// %2 = arith.mulf %0, %1 : vector<4x1xf32> +/// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32> +/// +/// Supports scalable vectors.
class CastAwayElementwiseLeadingOneDim : public RewritePattern { public: CastAwayElementwiseLeadingOneDim(MLIRContext *context, diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -276,6 +276,30 @@ return %0: vector<1x1x4xf32> } +// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_scalable( +// CHECK-SAME: %[[S:.*]]: f32, +// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { +func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<1x1x[4]xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32> +// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32> + %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32> + return %0: vector<1x1x[4]xf32> +} + +// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim( +// CHECK-SAME: %[[S:.*]]: f32, +// CHECK-SAME: %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> { +func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> { +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<1x[1]x4xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32> +// CHECK: return %[[BCAST]] : vector<1x[1]x4xf32> + %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32> + return %0: vector<1x[1]x4xf32> +} + // CHECK-LABEL: func
@cast_away_insert_leading_one_dims_rank1 // CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>) // CHECK: %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32> @@ -285,6 +309,16 @@ return %0: vector<1x1x4xf32> } +// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank1_scalable( +// CHECK-SAME: %[[S:.*]]: vector<[4]xf32>, +// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32> +// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32> +func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { + %0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32> + return %0: vector<1x1x[4]xf32> +} + // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2 // CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>) // CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32> @@ -295,6 +329,17 @@ return %0: vector<1x1x4xf32> } +// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_scalable( +// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>, +// CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32> +// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32> +func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { + %0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32> + return %0: vector<1x1x[4]xf32> +} + // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest // CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>) // CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x4xf32> @@ -307,6
+352,19 @@ return %0: vector<1x2x1x4xf32> } +// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable( +// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>, +// CHECK-SAME: %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> { +// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32> +// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<1x2x1x[4]xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32> +// CHECK: return %[[BCAST]] : vector<1x2x1x[4]xf32> +func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> { + %0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32> + return %0: vector<1x2x1x[4]xf32> +} + // CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest // CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>) // CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32> @@ -317,6 +375,17 @@ return %0: vector<8x1x4xf32> } +// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable( +// CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>, +// CHECK-SAME: %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> { +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32> +// CHECK: return %[[INSERT]] : vector<8x1x[4]xf32> +func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> { + %0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32> + return %0: vector<8x1x[4]xf32> +} + // CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest //
CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>) // CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x8xi1> @@ -328,3 +397,16 @@ %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1> return %0: vector<1x1x8x1x8xi1> } + +// CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable( +// CHECK-SAME: %[[S:.*]]: vector<1x[8]xi1>, +// CHECK-SAME: %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> { +// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<1x[8]xi1> +// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<1x1x8x1x[8]xi1> +// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1> +// CHECK: return %[[BCAST]] : vector<1x1x8x1x[8]xi1> +func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> { + %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1> + return %0: vector<1x1x8x1x[8]xi1> +} diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -34,6 +34,22 @@ return %1 : vector<2x[4]x1xf32> } +// CHECK-LABEL: func.func @cast_away_leading_one_dim( +// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32> +// CHECK: vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32> +func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> { + %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> + return %1: vector<1x4x1xf32> +} + +// CHECK-LABEL: func.func @cast_away_leading_one_dim_scalable( +// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} :
vector<[4]x1xf32> +// CHECK: vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32> +func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> { + %1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32> + return %1: vector<1x[4]x1xf32> +} + // CHECK-LABEL: func @add4x4 // CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>