diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -302,7 +302,7 @@ Results<(outs AnyType:$dest)> { let summary = "Multi-dimensional reduction operation"; let description = [{ - Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n) + Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n) using the given operation (add/mul/min/max for int/fp and and/or/xor for int only). @@ -380,7 +380,7 @@ PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, Arguments<(ins AnyType:$source)>, - Results<(outs AnyVector:$vector)> { + Results<(outs AnyVectorOfAnyRank:$vector)> { let summary = "broadcast operation"; let description = [{ Broadcasts the scalar or k-D vector value in the source operand diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -546,10 +546,27 @@ VectorType srcType = op.getSourceType().dyn_cast(); Type eltType = dstType.getElementType(); + // Scalar to any vector can use splat. + if (!srcType) { + rewriter.replaceOpWithNewOp(op, dstType, op.source()); + return success(); + } + // Determine rank of source and destination. - int64_t srcRank = srcType ? srcType.getRank() : 0; + int64_t srcRank = srcType.getRank(); int64_t dstRank = dstType.getRank(); + // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. + if (srcRank <= 1 && dstRank == 1) { + Value ext; + if (srcRank == 0) + ext = rewriter.create(loc, op.source()); + else + ext = rewriter.create(loc, op.source(), 0); + rewriter.replaceOpWithNewOp(op, dstType, ext); + return success(); + } + // Duplicate this rank. // For example: // %x = broadcast %y : k-D to n-D, k < n @@ -560,11 +577,6 @@ // %b = [%y,%y] : (n-1)-D // %x = [%b,%b,%b,%b] : n-D if (srcRank < dstRank) { - // Scalar to any vector can use splat. - if (srcRank == 0) { - rewriter.replaceOpWithNewOp(op, dstType, op.source()); - return success(); - } // Duplication. VectorType resType = VectorType::get(dstType.getShape().drop_front(), eltType); @@ -593,14 +605,6 @@ return success(); } - // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. - if (srcRank == 1) { - assert(m == 0); - Value ext = rewriter.create(loc, op.source(), 0); - rewriter.replaceOpWithNewOp(op, dstType, ext); - return success(); - } - // Any non-matching dimension forces a stretch along this rank. // For example: // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32> diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -35,6 +35,27 @@ // ----- +func @broadcast_vec0d_from_f32(%arg0: f32) -> vector { + %0 = vector.broadcast %arg0 : f32 to vector + return %0 : vector +} +// CHECK-LABEL: @broadcast_vec0d_from_f32 +// CHECK-SAME: %[[A:.*]]: f32) +// CHECK: %[[T0:.*]] = splat %[[A]] : vector +// CHECK: return %[[T0]] : vector + +// ----- + +func @broadcast_vec0d_from_vec0d(%arg0: vector) -> vector { + %0 = vector.broadcast %arg0 : vector to vector + return %0 : vector +} +// CHECK-LABEL: @broadcast_vec0d_from_vec0d( +// CHECK-SAME: %[[A:.*]]: vector) +// CHECK: return %[[A]] : vector + +// ----- + func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> { %0 = vector.broadcast %arg0 : f32 to vector<2xf32> return %0 : vector<2xf32> @@ -89,6 +110,26 @@ // ----- +func @broadcast_vec2d_from_vec0d(%arg0: vector) -> vector<3x2xf32> { + %0 = vector.broadcast %arg0 : vector to vector<3x2xf32> + return %0 : vector<3x2xf32> +} +// CHECK-LABEL: @broadcast_vec2d_from_vec0d( +// CHECK-SAME: %[[A:.*]]: vector) +// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector to vector<1xf32> +// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32> +// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<2xf32> +// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32> +// CHECK: return %[[T10]] : vector<3x2xf32> + +// ----- + func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> return %0 : vector<3x2xf32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -28,6 +28,33 @@ return } +func @broadcast_0d(%a: f32) { + %1 = vector.broadcast %a : f32 to vector + // CHECK: ( 42 ) + vector.print %1: vector + + %2 = vector.broadcast %1 : vector to vector + // CHECK: ( 42 ) + vector.print %2: vector + + %3 = vector.broadcast %1 : vector to vector<1xf32> + // CHECK: ( 42 ) + vector.print %3: vector<1xf32> + + %4 = vector.broadcast %1 : vector to vector<2xf32> + // CHECK: ( 42, 42 ) + vector.print %4: vector<2xf32> + + %5 = vector.broadcast %1 : vector to vector<2x1xf32> + // CHECK: ( ( 42 ), ( 42 ) ) + vector.print %5: vector<2x1xf32> + + %6 = vector.broadcast %1 : vector to vector<2x3xf32> + // CHECK: ( ( 42, 42, 42 ), ( 42, 42, 42 ) ) + vector.print %6: vector<2x3xf32> + return +} + func @entry() { %0 = arith.constant 42.0 : f32 %1 = arith.constant dense<0.0> : vector @@ -39,6 +66,7 @@ %4 = arith.constant 42.0 : f32 call @splat_0d(%4) : (f32) -> () + call @broadcast_0d(%4) : (f32) -> () return }