diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -765,6 +765,113 @@
   let hasVerifier = 1;
 }
 
+def Vector_ScalableInsertOp :
+  Vector_Op<"scalable.insert", [NoSideEffect,
+                                AllElementTypesMatch<["source", "dest"]>,
+                                AllTypesMatch<["dest", "res"]>,
+                                PredOpTrait<"position is a multiple of the source length.",
+                                  CPred<
+                                    "(getPos() % getSourceVectorType().getNumElements()) == 0"
+                                  >>]>,
+    Arguments<(ins VectorOfRank<[1]>:$source,
+                   ScalableVectorOfRank<[1]>:$dest,
+                   I64Attr:$pos)>,
+    Results<(outs ScalableVectorOfRank<[1]>:$res)> {
+  let summary = "insert subvector into scalable vector operation";
+  let description = [{
+    Takes a rank-1 fixed-length or scalable source `n-D` subvector, a rank-1
+    scalable destination `m-D` vector, and a position `pos` within the
+    destination vector. If `source` is fixed-length, the elements
+    `(0, ..., n-1)` of the `n-D` subvector are inserted into the scalable
+    vector starting at the specified position `pos` and up to `(pos + n - 1)`.
+    If the source vector is scalable, the insertion position is scaled by the
+    runtime vector scale `vscale`: the source elements
+    `(0, ..., vscale * n - 1)` are inserted into the destination starting at
+    position `(vscale * pos)` and up to `(vscale * pos + vscale * n - 1)`.
+
+    The insertion position must be a multiple of the minimum size of the source
+    vector. For the operation to be well defined, the source vector must fit in
+    the destination vector from the specified position. Since the destination
+    vector is scalable and its runtime length is unknown, this cannot be
+    verified or guaranteed at compile time.
+
+    Example:
+
+    ```mlir
+    %2 = vector.scalable.insert %0, %1[8] : vector<4xf32> into vector<[16]xf32>
+    %5 = vector.scalable.insert %3, %4[0] : vector<8xf32> into vector<[4]xf32>
+    %8 = vector.scalable.insert %6, %7[0] : vector<[4]xf32> into vector<[8]xf32>
+    ```
+
+    Invalid example:
+    ```mlir
+    %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getDestVectorType() {
+      return getDest().getType().cast<VectorType>();
+    }
+  }];
+}
+
+def Vector_ScalableExtractOp :
+  Vector_Op<"scalable.extract", [NoSideEffect,
+                                 AllElementTypesMatch<["source", "res"]>,
+                                 PredOpTrait<"position is a multiple of the result length.",
+                                   CPred<
+                                     "(getPos() % getResultVectorType().getNumElements()) == 0"
+                                   >>]>,
+    Arguments<(ins ScalableVectorOfRank<[1]>:$source,
+                   I64Attr:$pos)>,
+    Results<(outs VectorOfRank<[1]>:$res)> {
+  let summary = "extract subvector from scalable vector operation";
+  let description = [{
+    Takes a rank-1 scalable `n-D` source vector and a position `pos` within the
+    source vector, and extracts the elements `(pos, ..., pos + m - 1)` of the
+    source vector into a new rank-1 `m-D` vector. If the result vector is
+    scalable, the extraction position is scaled by the runtime vector scale
+    `vscale`, as for `vector.scalable.insert`.
+
+    The extraction position must be a multiple of the minimum size of the
+    result vector. For the operation to be well defined, the result vector must
+    fit in the source vector from the specified position. Since the source
+    vector is scalable and its runtime length is unknown, this cannot be
+    verified or guaranteed at compile time.
+
+    Example:
+
+    ```mlir
+    %1 = vector.scalable.extract %0[8] : vector<4xf32> from vector<[8]xf32>
+    %3 = vector.scalable.extract %2[0] : vector<[4]xf32> from vector<[8]xf32>
+    ```
+
+    Invalid example:
+    ```mlir
+    %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return getRes().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_InsertMapOp :
   Vector_Op<"insert_map", [NoSideEffect, AllTypesMatch<["dest", "result"]>]>,
     Arguments<(ins AnyVector:$vector, AnyVector:$dest, Variadic<Index>:$ids)>,
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -564,11 +564,39 @@
                            == }]
                          # allowedlength>)>]>;
 
+// Whether the rank of a fixed-length vector is from the given
+// `allowedRanks` list
+class IsFixedVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsFixedVectorTypePred,
+       Or<!foreach(allowedlength, allowedRanks,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+                           == }]
+                         # allowedlength>)>]>;
+
+// Whether the rank of a scalable vector is from the given
+// `allowedRanks` list
+class IsScalableVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsScalableVectorTypePred,
+       Or<!foreach(allowedlength, allowedRanks,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+                           == }]
+                         # allowedlength>)>]>;
+
 // Any vector where the rank is from the given `allowedRanks` list
 class VectorOfRank<list<int> allowedRanks> : Type<
   IsVectorOfRankPred<allowedRanks>,
   " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
 
+// Any fixed-length vector where the rank is from the given `allowedRanks` list
+class FixedVectorOfRank<list<int> allowedRanks> : Type<
+  IsFixedVectorOfRankPred<allowedRanks>,
+  " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+
+// Any scalable vector where the rank is from the given `allowedRanks` list
+class ScalableVectorOfRank<list<int> allowedRanks> : Type<
+  IsScalableVectorOfRankPred<allowedRanks>,
+  " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+
 // Any vector where the rank is from the given `allowedRanks` list and the type
 // is from the given `allowedTypes` list
 class VectorOfRankAndType<list<int> allowedRanks,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -737,6 +737,37 @@
   }
 };
 
+/// Lower vector.scalable.insert ops to LLVM vector.insert
+struct VectorScalableInsertOpLowering
+    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
+  using ConvertOpToLLVMPattern<
+      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
+        insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
+    return success();
+  }
+};
+
+/// Lower vector.scalable.extract ops to LLVM vector.extract
+struct VectorScalableExtractOpLowering
+    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
+  using ConvertOpToLLVMPattern<
+      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
+        extOp, typeConverter->convertType(extOp.getResultVectorType()),
+        adaptor.getSource(), adaptor.getPos());
+    return success();
+  }
+};
+
 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
 ///
 /// Example:
@@ -1214,7 +1245,9 @@
                                          vector::MaskedStoreOpAdaptor>,
                VectorGatherOpConversion, VectorScatterOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-               VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
+               VectorSplatOpLowering, VectorSplatNdOpLowering,
+               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
+      converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1814,3 +1814,25 @@
 // CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
 // CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
 // CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: @vector_scalable_insert
+// CHECK-SAME: %[[SUB:.*]]: vector<4xf32>, %[[SV:.*]]: vector<[4]xf32>
+func.func @vector_scalable_insert(%sub: vector<4xf32>, %dsv: vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK-NEXT: %[[TMP:.*]] = llvm.intr.vector.insert %[[SUB]], %[[SV]][0] : vector<4xf32> into vector<[4]xf32>
+  %0 = vector.scalable.insert %sub, %dsv[0] : vector<4xf32> into vector<[4]xf32>
+  // CHECK-NEXT: llvm.intr.vector.insert %[[SUB]], %[[TMP]][4] : vector<4xf32> into vector<[4]xf32>
+  %1 = vector.scalable.insert %sub, %0[4] : vector<4xf32> into vector<[4]xf32>
+  return %1 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_scalable_extract
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
+  // CHECK-NEXT: %{{.*}} = llvm.intr.vector.extract %[[VEC]][0] : vector<8xf32> from vector<[4]xf32>
+  %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
+  return %0 : vector<8xf32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1612,3 +1612,17 @@
   }
   return
 }
+
+// -----
+
+func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
+  %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
+}
+
+// -----
+
+func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}}
+  %0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -771,4 +771,26 @@
   return %2 : vector<4xi32>
 }
-
+// CHECK-LABEL: func @vector_scalable_insert(
+// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_insert(%sub0: vector<4xi32>, %sub1: vector<8xi32>,
+                                  %sub2: vector<[4]xi32>, %sv: vector<[8]xi32>) {
+  // CHECK-NEXT: vector.scalable.insert %[[SUB0]], %[[SV]][12] : vector<4xi32> into vector<[8]xi32>
+  %0 = vector.scalable.insert %sub0, %sv[12] : vector<4xi32> into vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.insert %[[SUB1]], %[[SV]][0] : vector<8xi32> into vector<[8]xi32>
+  %1 = vector.scalable.insert %sub1, %sv[0] : vector<8xi32> into vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.insert %[[SUB2]], %[[SV]][0] : vector<[4]xi32> into vector<[8]xi32>
+  %2 = vector.scalable.insert %sub2, %sv[0] : vector<[4]xi32> into vector<[8]xi32>
+  return
+}
+
+// CHECK-LABEL: func @vector_scalable_extract(
+// CHECK-SAME: %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
+  // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<16xi32> from vector<[8]xi32>
+  %0 = vector.scalable.extract %sv[0] : vector<16xi32> from vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<[4]xi32> from vector<[8]xi32>
+  %1 = vector.scalable.extract %sv[0] : vector<[4]xi32> from vector<[8]xi32>
+  return
+}
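
For reference, here is a minimal usage sketch (not part of the patch): it shows how the two new ops compose, with a fixed-length window extracted from a scalable vector, updated, and inserted back at the same aligned position. The function and value names are made up for illustration, and the `arith` dialect is assumed for the element-wise update; with the lowering added above, each op maps directly onto `llvm.intr.vector.extract` / `llvm.intr.vector.insert`.

```mlir
// Hypothetical example: update the second 4-element window of a scalable
// accumulator. Both positions are multiples of the fixed subvector length (4),
// as required by the verifiers of the new ops, and elements 4..7 always exist
// in vector<[8]xf32> since vscale >= 1.
func.func @update_window(%acc: vector<[8]xf32>, %scale: vector<4xf32>) -> vector<[8]xf32> {
  %win = vector.scalable.extract %acc[4] : vector<4xf32> from vector<[8]xf32>
  %scaled = arith.mulf %win, %scale : vector<4xf32>
  %res = vector.scalable.insert %scaled, %acc[4] : vector<4xf32> into vector<[8]xf32>
  return %res : vector<[8]xf32>
}
```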