diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -675,7 +675,7 @@ position and inserts the source into the destination at the proper position. Note that this instruction resembles vector.insert, but is restricted to 0-D - and 1-D vectors and relaxed to dynamic indices. + and 1-D vectors and relaxed to dynamic indices. It is meant to be closer to LLVM's version: https://llvm.org/docs/LangRef.html#insertelement-instruction @@ -2030,13 +2030,14 @@ def Vector_BitCastOp : Vector_Op<"bitcast", [NoSideEffect, AllRanksMatch<["source", "result"]>]>, - Arguments<(ins AnyVector:$source)>, - Results<(outs AnyVector:$result)>{ + Arguments<(ins AnyVectorOfAnyRank:$source)>, + Results<(outs AnyVectorOfAnyRank:$result)>{ let summary = "bitcast casts between vectors"; let description = [{ The bitcast operation casts between vectors of the same rank, the minor 1-D vector size is casted to a vector with a different element type but same - bitwidth. + bitwidth. In case of 0-D vectors, the bitwidth of element types must be + equal. Example: @@ -2049,6 +2050,9 @@ // Example casting to an element type of the same size. %5 = vector.bitcast %4 : vector<5x1x4x3xf32> to vector<5x1x4x3xi32> + + // Example casting of 0-D vectors. + %7 = vector.bitcast %6 : vector<f32> to vector<i32> ``` }]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -145,9 +145,9 @@ LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Only 1-D vectors can be lowered to LLVM. 
- VectorType resultTy = bitCastOp.getType(); - if (resultTy.getRank() != 1) + // Only 0-D and 1-D vectors can be lowered to LLVM. + VectorType resultTy = bitCastOp.getResultVectorType(); + if (resultTy.getRank() > 1) return failure(); Type newResultTy = typeConverter->convertType(resultTy); rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1613,8 +1613,8 @@ static_cast<unsigned>(destVectorType.getRank()))) return op.emitOpError("expected position attribute rank + source rank to " "match dest vector rank"); - if (!srcVectorType && (positionAttr.size() != - static_cast<unsigned>(destVectorType.getRank()))) + if (!srcVectorType && + (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank()))) return op.emitOpError( "expected position attribute rank to match the dest vector rank"); for (auto en : llvm::enumerate(positionAttr)) { @@ -3724,12 +3724,20 @@ } DataLayout dataLayout = DataLayout::closest(op); - if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) * - sourceVectorType.getShape().back() != - dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) * - resultVectorType.getShape().back()) + auto sourceElementBits = + dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); + auto resultElementBits = + dataLayout.getTypeSizeInBits(resultVectorType.getElementType()); + + if (sourceVectorType.getRank() == 0) { + if (sourceElementBits != resultElementBits) + return op.emitOpError("source/result bitwidth of the 0-D vector element " + "types must be equal"); + } else if (sourceElementBits * sourceVectorType.getShape().back() != + resultElementBits * resultVectorType.getShape().back()) { return op.emitOpError( "source/result bitwidth of the minor 1-D vectors must be equal"); + } return success(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir 
b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,6 +1,20 @@ // RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s +func @bitcast_f32_to_i32_vector_0d(%input: vector<f32>) -> vector<i32> { + %0 = vector.bitcast %input : vector<f32> to vector<i32> + return %0 : vector<i32> +} + +// CHECK-LABEL: @bitcast_f32_to_i32_vector_0d +// CHECK-SAME: %[[input:.*]]: vector<f32> +// CHECK: %[[vec_f32_1d:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<f32> to vector<1xf32> +// CHECK: %[[vec_i32_1d:.*]] = llvm.bitcast %[[vec_f32_1d]] : vector<1xf32> to vector<1xi32> +// CHECK: %[[vec_i32_0d:.*]] = builtin.unrealized_conversion_cast %[[vec_i32_1d]] : vector<1xi32> to vector<i32> +// CHECK: return %[[vec_i32_0d]] : vector<i32> + +// ----- + func @bitcast_f32_to_i32_vector(%input: vector<16xf32>) -> vector<16xi32> { %0 = vector.bitcast %input : vector<16xf32> to vector<16xi32> return %0 : vector<16xi32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -79,7 +79,7 @@ } // ----- - + func @extract_element(%arg0: vector<4xf32>) { %c = arith.constant 3 : i32 // expected-error@+1 {{expected position for 1-D vector}} @@ -1007,6 +1007,20 @@ // ----- +func @bitcast_rank_mismatch_to_0d(%arg0 : vector<1xf32>) { + // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}} + %0 = vector.bitcast %arg0 : vector<1xf32> to vector<f32> +} + +// ----- + +func @bitcast_rank_mismatch_from_0d(%arg0 : vector<f32>) { + // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}} + %0 = vector.bitcast %arg0 : vector<f32> to vector<1xf32> +} + +// ----- + func @bitcast_rank_mismatch(%arg0 : vector<5x1x3x2xf32>) { // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}} %0 = 
vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x3x2xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -412,8 +412,9 @@ func @bitcast(%arg0 : vector<5x1x3x2xf32>, %arg1 : vector<8x1xi32>, %arg2 : vector<16x1x8xi8>, - %arg3 : vector<8x2x1xindex>) - -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>) { + %arg3 : vector<8x2x1xindex>, + %arg4 : vector<f32>) + -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>, vector<i32>) { // CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16> %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16> @@ -439,7 +440,10 @@ // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x2x1xindex> to vector<8x2x2xf32> %7 = vector.bitcast %arg3 : vector<8x2x1xindex> to vector<8x2x2xf32> - return %0, %1, %2, %3, %4, %5, %6, %7 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32> + // CHECK: vector.bitcast %{{.*}} : vector<f32> to vector<i32> + %8 = vector.bitcast %arg4 : vector<f32> to vector<i32> + + return %0, %1, %2, %3, %4, %5, %6, %7, %8 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>, vector<i32> } // CHECK-LABEL: @vector_fma diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -55,6 +55,19 @@ return } +func @bitcast_0d() { + %0 = arith.constant 42 : i32 + %1 = arith.constant 
dense<0> : vector<i32> + %2 = vector.insertelement %0, %1[] : vector<i32> + %3 = vector.bitcast %2 : vector<i32> to vector<f32> + %4 = vector.extractelement %3[] : vector<f32> + %5 = arith.bitcast %4 : f32 to i32 + // CHECK: 42 + vector.print %5: i32 + return +} + + func @entry() { %0 = arith.constant 42.0 : f32 %1 = arith.constant dense<0.0> : vector<f32> @@ -68,5 +81,7 @@ call @splat_0d(%4) : (f32) -> () call @broadcast_0d(%4) : (f32) -> () + call @bitcast_0d() : () -> () + return }