diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2052,6 +2052,67 @@
   let hasVerifier = 1;
 }
 
+def Vector_ScalableCastOp :
+  Vector_Op<"scalable_cast", [NoSideEffect,
+     AllElementTypesMatch<["source", "result"]>,
+     PredOpTrait<"is a cast between scalable and fixed-length vectors", CPred<
+       "getSourceVectorType().isScalable() != getResultVectorType().isScalable()"
+     >>,
+     PredOpTrait<"fixed number of elements is multiple of scalable number of elements",
+       Or<[
+         And<[CPred<"getSourceVectorType().isScalable()">,
+              Neg<CPred<"getResultVectorType().getNumElements() % getSourceVectorType().getNumElements()">>]>,
+         And<[CPred<"getResultVectorType().isScalable()">,
+              Neg<CPred<"getSourceVectorType().getNumElements() % getResultVectorType().getNumElements()">>]>
+       ]>>
+  ]>,
+  Arguments<(ins VectorOfRank<[1]>:$source)>,
+  Results<(outs VectorOfRank<[1]>:$result)> {
+  let summary = "scalable_cast casts between scalable and fixed-length vectors";
+  let description = [{
+    The scalable_cast operation performs the cast between a scalable vector and
+    a fixed-length vector by assuming a vector scale.
+
+    This operation allows interfacing VLS (Vector-Length Specific) code with
+    scalable vector functions and operations, such as the ones found in the
+    ArmSVE dialect.
+
+    The scalable_cast operation doesn't allow a cast between element types. If
+    needed, an element-type cast should be performed before or after the cast
+    between scalable and fixed-length types.
+
+    For the cast to be valid, the number of elements of the fixed-length
+    vector must be a multiple of the number of elements of the scalable vector.
+
+    Example:
+
+    ```mlir
+    // Casting from a scalable vector to a fixed-length vector assuming a
+    // vector scale of 2
+    %1 = vector.scalable_cast %0 : vector<[32]xf32> to vector<64xf32>
+
+    // Casting from a fixed-length vector to a scalable vector assuming a
+    // vector scale of 16
+    %3 = vector.scalable_cast %2 : vector<128xf32> to vector<[8]xf32>
+
+    // Casting from a fixed-length vector to a scalable vector assuming a
+    // vector scale of 1
+    %5 = vector.scalable_cast %4 : vector<8xbf16> to vector<[8]xbf16>
+
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+  let hasFolder = 1;
+}
+
 def Vector_BitCastOp :
   Vector_Op<"bitcast", [NoSideEffect, AllRanksMatch<["source", "result"]>]>,
     Arguments<(ins AnyVectorOfAnyRank:$source)>,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -791,6 +791,29 @@
   }
 };
 
+class VectorScalableCastOpRewritePattern
+    : public OpRewritePattern<ScalableCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScalableCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = op.getResultVectorType();
+    if (resultType.isScalable()) {
+      // Fixed-length to scalable cast
+      Value insVec = rewriter.create<arith::ConstantOp>(
+          op.getLoc(), resultType, rewriter.getZeroAttr(resultType));
+      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_insert>(
+          op, op.getSource(), insVec, rewriter.getI64IntegerAttr(0));
+    } else {
+      // Scalable to fixed-length cast
+      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_extract>(
+          op, resultType, op.getSource(), rewriter.getI64IntegerAttr(0));
+    }
+    return success();
+  }
+};
+
 /// Returns the strides if the memory underlying `memRefType` has a contiguous
 /// static layout.
 static llvm::Optional<SmallVector<int64_t, 4>>
@@ -1196,7 +1219,8 @@
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions, bool force32BitVectorIndices) {
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
+  patterns.add<VectorFMAOpNDRewritePattern, VectorScalableCastOpRewritePattern>(
+      ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4201,6 +4201,17 @@
   results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// VectorScalableCastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ScalableCastOp::fold(ArrayRef<Attribute> operands) {
+  if (auto otherOp = getSource().getDefiningOp<ScalableCastOp>())
+    if (getResult().getType() == otherOp.getSource().getType())
+      return otherOp.getSource();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // VectorBitCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1814,3 +1814,30 @@
 // CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
 // CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
 // CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: @scalable_cast_fixed_to_scalable
+// CHECK-SAME: %[[A:arg[0-9]+]]: vector<8xi32>
+func.func @scalable_cast_fixed_to_scalable(%a: vector<8xi32>)
+    -> vector<[4]xi32> {
+  // CHECK-NEXT: %[[C0:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-NEXT: %[[R:[0-9]+]] = llvm.intr.experimental.vector.insert
+  // CHECK-SAME: %[[A]], %[[C0]][0] : vector<8xi32> into vector<[4]xi32>
+  // CHECK-NEXT: return %[[R]] : vector<[4]xi32>
+  %0 = vector.scalable_cast %a : vector<8xi32> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @scalable_cast_scalable_to_fixed
+// CHECK-SAME: %[[A:arg[0-9]+]]: vector<[4]xf32>
+func.func @scalable_cast_scalable_to_fixed(%a: vector<[4]xf32>)
+    -> vector<8xf32> {
+  // CHECK-NEXT: %[[R:[0-9]+]] = llvm.intr.experimental.vector.extract
+  // CHECK-SAME: %[[A]][0] : vector<8xf32> from vector<[4]xf32>
+  // CHECK-NEXT: return %[[R]] : vector<8xf32>
+  %0 = vector.scalable_cast %a : vector<[4]xf32> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1605,3 +1605,14 @@
   %s = vector.reduction <add>, %a : vector<4xf32> into f32
   return %s : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @scalable_cast_fold_complementary
+// CHECK-SAME: (%[[V:.+]]: vector<8xf32>)
+// CHECK-NEXT: return %[[V]] : vector<8xf32>
+func.func @scalable_cast_fold_complementary(%a : vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.scalable_cast %a : vector<8xf32> to vector<[4]xf32>
+  %1 = vector.scalable_cast %0 : vector<[4]xf32> to vector<8xf32>
+  return %1 : vector<8xf32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1612,3 +1612,27 @@
   }
   return
 }
+
+// -----
+
+func.func @scalable_cast_element_type_mismatch(%vin : vector<8xf32>) {
+  // expected-error@+1 {{op failed to verify that all of {source, result} have same element type}}
+  %0 = vector.scalable_cast %vin : vector<8xf32> to vector<[4]xbf16>
+  return
+}
+
+// -----
+
+func.func @scalable_cast_non_scalable_cast(%vin : vector<8xf32>) {
+  // expected-error@+1 {{op failed to verify that is a cast between scalable and fixed-length vectors}}
+  %0 = vector.scalable_cast %vin : vector<8xf32> to vector<8xf32>
+  return
+}
+
+// -----
+
+func.func @scalable_cast_invalid_cast(%vin : vector<8xf32>) {
+  // expected-error@+1 {{op failed to verify that fixed number of elements is multiple of scalable number of elements}}
+  %0 = vector.scalable_cast %vin : vector<8xf32> to vector<[16]xf32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -771,4 +771,16 @@
   return %2 : vector<4xi32>
 }
-
+// CHECK-LABEL: @vector_scalable_casts
+func.func @vector_scalable_casts(%0: vector<8xf32>,
+                                 %1: vector<[4]xf32>) {
+  // CHECK-NEXT: %{{.*}} = vector.scalable_cast %{{.*}}: vector<8xf32> to vector<[8]xf32>
+  %2 = vector.scalable_cast %0 : vector<8xf32> to vector<[8]xf32>
+  // CHECK-NEXT: %{{.*}} = vector.scalable_cast %{{.*}}: vector<8xf32> to vector<[4]xf32>
+  %3 = vector.scalable_cast %0 : vector<8xf32> to vector<[4]xf32>
+  // CHECK-NEXT: %{{.*}} = vector.scalable_cast %{{.*}}: vector<[4]xf32> to vector<4xf32>
+  %4 = vector.scalable_cast %1 : vector<[4]xf32> to vector<4xf32>
+  // CHECK-NEXT: %{{.*}} = vector.scalable_cast %{{.*}}: vector<[4]xf32> to vector<8xf32>
+  %5 = vector.scalable_cast %1 : vector<[4]xf32> to vector<8xf32>
+  return
+}
 