diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -1439,8 +1439,7 @@ namespace mlir { -template -UnderlyingTy extractValue(Attribute attr); +template UnderlyingTy extractValue(Attribute attr); /// This class implements underlying value iterator for attributes. It is /// templated by both the value and attribute types. @@ -1457,8 +1456,7 @@ }; // Explicit instantiations and definitions of attribute value iterators. -template <> -inline int64_t extractValue(Attribute attr) { +template <> inline int64_t extractValue(Attribute attr) { return attr.cast().getValue().getSExtValue(); } using i64_attr_iterator = attr_value_iterator; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -70,6 +70,17 @@ rewriter.getI64ArrayAttr(pos)); } +// Helper that picks the proper sequence for inserting. +static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, + Value into, int64_t offset) { + auto vectorType = into.getType().cast(); + if (vectorType.getRank() > 1) + return rewriter.create(loc, from, into, offset); + return rewriter.create( + loc, vectorType, from, into, + rewriter.create(loc, offset)); +} + // Helper that picks the proper sequence for extracting. static Value extractOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &lowering, Location loc, Value val, @@ -86,6 +97,27 @@ rewriter.getI64ArrayAttr(pos)); } +// Helper that picks the proper sequence for extracting. 
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, + int64_t offset) { + auto vectorType = vector.getType().cast(); + if (vectorType.getRank() > 1) + return rewriter.create(loc, vector, offset); + return rewriter.create( + loc, vectorType.getElementType(), vector, + rewriter.create(loc, offset)); +} + +// Helper that returns a subset of `arrayAttr` as a vector of int64_t. +static SmallVector getI64SubArray(ArrayAttr arrayAttr, + unsigned dropFront = 1, + unsigned dropBack = 0) { + auto range = + llvm::make_range(i64_attr_iterator(arrayAttr.begin() + dropFront), + i64_attr_iterator(arrayAttr.end() - dropBack)); + return llvm::to_vector<4>(range); +} + class VectorBroadcastOpConversion : public LLVMOpLowering { public: explicit VectorBroadcastOpConversion(MLIRContext *context, @@ -464,6 +496,132 @@ } }; +// RewritePattern for InsertStridedSliceOp where source and destination vectors +// have different ranks. In this case: +// 1. the proper subvector is extracted from the destination vector +// 2. a new InsertStridedSlice op is created to insert the source in the +// destination subvector +// 3. the destination subvector is inserted back in the proper place +// 4. the op is replaced by the result of step 3. +// The new InsertStridedSlice from step 2. will be picked up by a +// `VectorInsertStridedSliceOpSameRankRewritePattern`. 
+class VectorInsertStridedSliceOpDifferentRankRewritePattern
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto srcType = op.getSourceVectorType();
+    auto dstType = op.getDestVectorType();
+
+    if (op.offsets().getValue().empty())
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    int64_t rankDiff = dstType.getRank() - srcType.getRank();
+    assert(rankDiff >= 0);
+    if (rankDiff == 0)
+      return matchFailure();
+
+    int64_t rankRest = dstType.getRank() - rankDiff;
+    // Extract / insert the subvector of matching rank and InsertStridedSlice
+    // on it.
+    Value extracted =
+        rewriter.create<ExtractOp>(loc, op.dest(),
+                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
+                                                  /*dropBack=*/rankRest));
+    // A different pattern will kick in for InsertStridedSlice with matching
+    // ranks.
+    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
+        loc, op.source(), extracted,
+        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
+        getI64SubArray(op.strides(), /*dropFront=*/rankDiff));
+    Value inserted = rewriter.create<InsertOp>(
+        loc, stridedSliceInnerOp.getResult(), op.dest(),
+        getI64SubArray(op.offsets(), /*dropFront=*/0,
+                       /*dropBack=*/rankRest));
+    rewriter.replaceOp(op, inserted);
+    return matchSuccess();
+  }
+};
+
+// RewritePattern for InsertStridedSliceOp where source and destination vectors
+// have the same rank. In this case, for each subvector along the most major
+// dimension:
+//   1. the proper subvector (or element) is extracted from the source vector
+//   2. if it is a subvector, the matching subvector is extracted from the
+//      destination and the problem is reduced to a new, rank-reduced
+//      InsertStridedSlice which is rewritten recursively
+//   3. the (possibly rewritten) subvector is inserted back into the result
+// The op is finally replaced by the fully updated destination vector. 
+class VectorInsertStridedSliceOpSameRankRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(InsertStridedSliceOp op, + PatternRewriter &rewriter) const override { + auto srcType = op.getSourceVectorType(); + auto dstType = op.getDestVectorType(); + + if (op.offsets().getValue().empty()) + return matchFailure(); + + int64_t rankDiff = dstType.getRank() - srcType.getRank(); + assert(rankDiff >= 0); + if (rankDiff != 0) + return matchFailure(); + + if (srcType == dstType) { + rewriter.replaceOp(op, op.source()); + return matchSuccess(); + } + + int64_t offset = + op.offsets().getValue().front().cast().getInt(); + int64_t size = srcType.getShape().front(); + int64_t stride = + op.strides().getValue().front().cast().getInt(); + + auto loc = op.getLoc(); + Value res = op.dest(); + // For each slice of the source vector along the most major dimension. + for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; + off += stride, ++idx) { + // 1. extract the proper subvector (or element) from source + Value extractedSource = extractOne(rewriter, loc, op.source(), idx); + if (extractedSource.getType().isa()) { + // 2. If we have a vector, extract the proper subvector from destination + // Otherwise we are at the element level and no need to recurse. + Value extractedDest = extractOne(rewriter, loc, op.dest(), off); + // 3. Reduce the problem to lowering a new InsertStridedSlice op with + // smaller rank. + InsertStridedSliceOp insertStridedSliceOp = + rewriter.create( + loc, extractedSource, extractedDest, + getI64SubArray(op.offsets()), getI64SubArray(op.strides())); + // Call matchAndRewrite recursively from within the pattern. This + // circumvents the current limitation that a given pattern cannot + // be called multiple times by the PatternRewrite infrastructure (to + // avoid infinite recursion, but in this case we know what we are + // doing). 
+ // TODO(rriddle, nicolasvasilache) Implement something like a hook for + // a potential function that must decrease and allow the same pattern + // multiple times. + if (!matchAndRewrite(insertStridedSliceOp, rewriter)) + return matchFailure(); + extractedSource = insertStridedSliceOp; + } + // 4. Insert the extractedSource into the res vector. + res = insertOne(rewriter, loc, extractedSource, res, off); + } + + rewriter.replaceOp(op, res); + return matchSuccess(); + } +}; + class VectorOuterProductOpConversion : public LLVMOpLowering { public: explicit VectorOuterProductOpConversion(MLIRContext *context, @@ -725,40 +883,10 @@ } }; -static SmallVector getI64SubArray(ArrayAttr arrayAttr, - unsigned drop_front = 1, - unsigned drop_back = 0) { - auto range = - llvm::make_range(i64_attr_iterator(arrayAttr.begin() + drop_front), - i64_attr_iterator(arrayAttr.end() - drop_back)); - return llvm::to_vector<4>(range); -} - -static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, - int64_t offset) { - auto vectorType = vector.getType().cast(); - if (vectorType.getRank() > 1) - return rewriter.create(loc, vector, offset); - return rewriter.create( - loc, vectorType.getElementType(), vector, - rewriter.create(loc, offset)); -} - -static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, - Value into, int64_t offset) { - auto vectorType = into.getType().cast(); - if (vectorType.getRank() > 1) - return rewriter.create(loc, from, into, offset); - return rewriter.create( - loc, vectorType, from, into, - rewriter.create(loc, offset)); -} - /// Progressive lowering of StridedSliceOp to either: /// 1. extractelement + insertelement for the 1-D case /// 2. extract + optional strided_slice + insert for the n-D case. 
-class VectorStridedSliceOpRewritePattern - : public OpRewritePattern { +class VectorStridedSliceOpConversion : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -811,7 +939,9 @@ void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { MLIRContext *ctx = converter.getDialect()->getContext(); - patterns.insert(ctx); + patterns.insert(ctx); patterns.insert, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) { // CHECK-LABEL: llvm.func @strided_slice( - %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> // CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>"> @@ -483,4 +482,45 @@ return } +func @insert_strided_slice(%a: vector<2x2xf32>, %b: vector<4x4xf32>, %c: vector<4x4x4xf32>) { +// CHECK-LABEL: @insert_strided_slice + + %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32> +// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]"> +// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]"> + + %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// +// Subvector vector<2xf32> @0 into vector<4xf32> @2 +// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <2 x float>]"> +// CHECK-NEXT: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <4 x float>]"> +// Element @0 -> element @2 +// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK-NEXT: llvm.mlir.constant(2 : index) : !llvm.i64 +// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> +// Element @1 -> element @3 +// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 +// 
CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK-NEXT: llvm.mlir.constant(3 : index) : !llvm.i64 +// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> +// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <4 x float>]"> +// +// Subvector vector<2xf32> @1 into vector<4xf32> @3 +// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <2 x float>]"> +// CHECK-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]"> +// Element @0 -> element @2 +// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK-NEXT: llvm.mlir.constant(2 : index) : !llvm.i64 +// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> +// Element @1 -> element @3 +// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK-NEXT: llvm.mlir.constant(3 : index) : !llvm.i64 +// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> +// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]"> + + return +}