diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1525,6 +1525,41 @@
   let hasFolder = 1;
 }
 
+def Vector_BitCastOp :
+  Vector_Op<"bitcast", [NoSideEffect, AllRanksMatch<["source", "result"]>]>,
+    Arguments<(ins AnyVector:$source)>,
+    Results<(outs AnyVector:$result)> {
+  let summary = "bitcast casts between vectors";
+  let description = [{
+    The bitcast operation casts between vectors of the same rank. All leading
+    dimensions must match; only the minor 1-D vector is reinterpreted, as a
+    vector with a different element type but the same total bitwidth.
+
+    Example:
+
+    ```mlir
+    // Example casting to a smaller element type.
+    %1 = vector.bitcast %0 : vector<5x1x4x3xf32> to vector<5x1x4x6xi16>
+
+    // Example casting to a bigger element type.
+    %3 = vector.bitcast %2 : vector<10x12x8xi8> to vector<10x12x2xi32>
+
+    // Example casting to an element type of the same size.
+    %5 = vector.bitcast %4 : vector<5x1x4x3xf32> to vector<5x1x4x3xi32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return source().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+  let hasFolder = 1;
+}
+
 def Vector_TypeCastOp :
   Vector_Op<"type_cast", [NoSideEffect]>,
     Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2300,6 +2300,60 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// VectorBitCastOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `op` only reinterprets the minor (innermost) dimension: all
+/// leading dimensions must match, and the minor 1-D vectors of the source and
+/// result must have the same total bitwidth.
+static LogicalResult verify(BitCastOp op) {
+  auto sourceVectorType = op.getSourceVectorType();
+  auto resultVectorType = op.getResultVectorType();
+
+  // Equal ranks are already enforced by the AllRanksMatch trait; only the
+  // leading (non-minor) dimensions need to be compared here.
+  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
+    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
+      return op.emitOpError("dimension size mismatch at: ") << i;
+  }
+
+  if (sourceVectorType.getElementTypeBitWidth() *
+          sourceVectorType.getShape().back() !=
+      resultVectorType.getElementTypeBitWidth() *
+          resultVectorType.getShape().back())
+    return op.emitOpError(
+        "source/result bitwidth of the minor 1-D vectors must be equal");
+
+  return success();
+}
+
+/// Folds:
+///   * a no-op bitcast (identical source and result types) to its source;
+///   * bitcast(bitcast(%x : A to B) : B to A) to %x;
+///   * bitcast(bitcast(%x : A to B) : B to C) in place into
+///     bitcast(%x : A to C). This is sound because the verifier requires each
+///     cast to keep the leading dimensions and the minor 1-D bitwidth, so the
+///     direct A to C cast satisfies the same constraints.
+OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
+  // Nop cast.
+  if (source().getType() == result().getType())
+    return source();
+
+  if (auto otherOp = source().getDefiningOp<BitCastOp>()) {
+    // Canceling bitcasts.
+    if (result().getType() == otherOp.source().getType())
+      return otherOp.source();
+
+    // Composing bitcasts: read straight from the producer's source. Returning
+    // this op's own result signals that it was folded in place.
+    getOperation()->setOperand(0, otherOp.source());
+    return result();
+  }
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TypeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -372,3 +372,28 @@
   // CHECK: return
   return %1, %2 : vector<4x8xf32>, vector<4x9xf32>
 }
+
+// -----
+
+// CHECK-LABEL: bitcast_folding
+//  CHECK-SAME: %[[A:.*]]: vector<4x8xf32>
+//  CHECK-SAME: %[[B:.*]]: vector<2xi32>
+//       CHECK: return %[[A]], %[[B]] : vector<4x8xf32>, vector<2xi32>
+func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf32>, vector<2xi32>) {
+  %0 = vector.bitcast %I1 : vector<4x8xf32> to vector<4x8xf32>
+  %1 = vector.bitcast %I2 : vector<2xi32> to vector<4xi16>
+  %2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
+  return %0, %2 : vector<4x8xf32>, vector<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: bitcast_compose
+//  CHECK-SAME: %[[A:.*]]: vector<4xf32>
+//       CHECK: %[[C:.*]] = vector.bitcast %[[A]] : vector<4xf32> to vector<8xi16>
+//  CHECK-NEXT: return %[[C]] : vector<8xi16>
+func @bitcast_compose(%arg0: vector<4xf32>) -> vector<8xi16> {
+  %0 = vector.bitcast %arg0 : vector<4xf32> to vector<16xi8>
+  %1 = vector.bitcast %0 : vector<16xi8> to vector<8xi16>
+  return %1 : vector<8xi16>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1065,6 +1065,34 @@
 
 // -----
 
+func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
+  // expected-error@+1 {{must be vector of any type values}}
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32
+}
+
+// -----
+
+func @bitcast_rank_mismatch(%arg0 : vector<5x1x3x2xf32>) {
+  // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}}
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x3x2xf32>
+}
+
+// -----
+
+func @bitcast_shape_mismatch(%arg0 : vector<5x1x3x2xf32>) {
+  // expected-error@+1 {{op dimension size mismatch}}
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x2x3x2xf32>
+}
+
+// -----
+
+func @bitcast_sizemismatch(%arg0 : vector<5x1x3x2xf32>) {
+  // expected-error@+1 {{op source/result bitwidth of the minor 1-D vectors must be equal}}
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x3xf16>
+}
+
+// -----
+
 func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
   // expected-error@+1 {{'vector.reduction' op unknown reduction kind: joho}}
   %0 = vector.reduction "joho", %arg0 : vector<16xf32> into f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -298,6 +298,33 @@
   return %0, %1, %2, %3, %4 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
 }
 
+// CHECK-LABEL: @bitcast
+func @bitcast(%arg0 : vector<5x1x3x2xf32>,
+              %arg1 : vector<8x1xi32>,
+              %arg2 : vector<16x1x8xi8>)
+  -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>) {
+
+  // CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
+  %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x8xi8>
+  %1 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x8xi8>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x1xi32> to vector<8x4xi8>
+  %2 = vector.bitcast %arg1 : vector<8x1xi32> to vector<8x4xi8>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x1xi32> to vector<8x1xf32>
+  %3 = vector.bitcast %arg1 : vector<8x1xi32> to vector<8x1xf32>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x2xi32>
+  %4 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x2xi32>
+
+  // CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x4xi16>
+  %5 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x4xi16>
+
+  return %0, %1, %2, %3, %4, %5 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>
+}
+
 // CHECK-LABEL: @vector_fma
 func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
   // CHECK: vector.fma %{{.*}} : vector<8xf32>