diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -70,6 +70,17 @@
                                               rewriter.getI64ArrayAttr(pos));
 }
 
+// Emits the proper `InsertOp` or `InsertElementOp` depending on `into`'s rank.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+                       Value into, int64_t offset) {
+  auto vectorType = into.getType().cast();
+  if (vectorType.getRank() > 1)
+    return rewriter.create(loc, from, into, offset);
+  return rewriter.create(
+      loc, vectorType, from, into,
+      rewriter.create(loc, offset));
+}
+
 // Helper that picks the proper sequence for extracting.
 static Value extractOne(ConversionPatternRewriter &rewriter,
                         LLVMTypeConverter &lowering, Location loc, Value val,
@@ -86,6 +97,32 @@
                                                rewriter.getI64ArrayAttr(pos));
 }
 
+// Emits the proper `ExtractOp` or `ExtractElementOp` depending on the rank.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+                        int64_t offset) {
+  auto vectorType = vector.getType().cast();
+  if (vectorType.getRank() > 1)
+    return rewriter.create(loc, vector, offset);
+  return rewriter.create(
+      loc, vectorType.getElementType(), vector,
+      rewriter.create(loc, offset));
+}
+
+// Returns arrayAttr[dropFront : size - dropBack] as int64_t (non-empty).
+// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
+static SmallVector getI64SubArray(ArrayAttr arrayAttr,
+                                              unsigned dropFront = 0,
+                                              unsigned dropBack = 0) {
+  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+  auto range = arrayAttr.getAsRange();
+  SmallVector res;
+  res.reserve(arrayAttr.size() - dropFront - dropBack);
+  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+       it != eit; ++it)
+    res.push_back((*it).getValue().getSExtValue());
+  return res;
+}
+
 class VectorBroadcastOpConversion : public LLVMOpLowering {
 public:
   explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -464,6 +501,139 @@
   }
 };
 
+// When ranks are different, InsertStridedSlice needs to extract a properly
+// ranked vector from the destination vector into which to insert. This pattern
+// only takes care of this part and forwards the rest of the conversion to
+// another pattern that converts InsertStridedSlice for operands of the same
+// rank.
+//
+// RewritePattern for InsertStridedSliceOp where source and destination vectors
+// have different ranks. In this case:
+//   1. the proper subvector is extracted from the destination vector
+//   2. a new InsertStridedSlice op is created to insert the source in the
+//   destination subvector
+//   3. the destination subvector is inserted back in the proper place
+//   4. the op is replaced by the result of step 3.
+// The new InsertStridedSlice from step 2. will be picked up by a
+// `VectorInsertStridedSliceOpSameRankRewritePattern`.
+class VectorInsertStridedSliceOpDifferentRankRewritePattern
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto srcType = op.getSourceVectorType();
+    auto dstType = op.getDestVectorType();
+
+    if (op.offsets().getValue().empty())
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    int64_t rankDiff = dstType.getRank() - srcType.getRank();
+    assert(rankDiff >= 0);
+    if (rankDiff == 0)
+      return matchFailure();
+
+    int64_t rankRest = dstType.getRank() - rankDiff;
+    // Extract the rank-matched subvector at the leading `rankDiff` offsets;
+    // `strides` has one entry per source dim, so it is forwarded whole.
+    Value extracted =
+        rewriter.create<ExtractOp>(loc, op.dest(),
+                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
+                                                  /*dropBack=*/rankRest));
+    // A different pattern will kick in for InsertStridedSlice with matching
+    // ranks.
+    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
+        loc, op.source(), extracted,
+        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
+        getI64SubArray(op.strides(), /*dropFront=*/0));
+    rewriter.replaceOpWithNewOp<InsertOp>(
+        op, stridedSliceInnerOp.getResult(), op.dest(),
+        getI64SubArray(op.offsets(), /*dropFront=*/0,
+                       /*dropBack=*/rankRest));
+    return matchSuccess();
+  }
+};
+
+// RewritePattern for InsertStridedSliceOp where source and destination vectors
+// have the same rank. For each slice of the source along the most major
+// dimension:
+//   1. the slice (subvector or element) is extracted from the source vector
+//   2. if it is a subvector, the matching subvector is extracted from the
+//   destination and a lower-rank InsertStridedSlice is created and lowered
+//   by recursively calling this same pattern
+//   3. the result is inserted into the destination at offset + idx * stride.
+// Replaces the op with the updated destination (or source, if types match).
+class VectorInsertStridedSliceOpSameRankRewritePattern
+    : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto srcType = op.getSourceVectorType();
+    auto dstType = op.getDestVectorType();
+
+    if (op.offsets().getValue().empty())
+      return matchFailure();
+
+    int64_t rankDiff = dstType.getRank() - srcType.getRank();
+    assert(rankDiff >= 0);
+    if (rankDiff != 0)
+      return matchFailure();
+
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, op.source());
+      return matchSuccess();
+    }
+
+    int64_t offset =
+        op.offsets().getValue().front().cast().getInt();
+    int64_t size = srcType.getShape().front();
+    int64_t stride =
+        op.strides().getValue().front().cast().getInt();
+
+    auto loc = op.getLoc();
+    Value res = op.dest();
+    // For each slice of the source vector along the most major dimension.
+    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+         off += stride, ++idx) {
+      // 1. extract slice `idx` (a subvector or an element) from the source.
+      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
+      if (extractedSource.getType().isa()) {
+        // 2. If we have a vector, extract the subvector at `off` from the
+        // destination. Otherwise we are at the element level: no recursion.
+        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
+        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
+        // smaller rank.
+        InsertStridedSliceOp insertStridedSliceOp =
+            rewriter.create(
+                loc, extractedSource, extractedDest,
+                getI64SubArray(op.offsets(), /*dropFront=*/1),
+                getI64SubArray(op.strides(), /*dropFront=*/1));
+        // Call matchAndRewrite recursively from within the pattern. This
+        // circumvents the current limitation that a given pattern cannot
+        // be called multiple times by the PatternRewrite infrastructure (to
+        // avoid infinite recursion, but in this case, infinite recursion
+        // cannot happen because the rank is strictly decreasing).
+        // TODO(rriddle, nicolasvasilache) Implement something like a hook for
+        // a potential function that must decrease and allow the same pattern
+        // multiple times.
+        auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
+        (void)success;
+        assert(success && "Unexpected failure");
+        extractedSource = insertStridedSliceOp;
+      }
+      // 4. Insert the (possibly updated) slice into `res` at position `off`.
+      res = insertOne(rewriter, loc, extractedSource, res, off);
+    }
+
+    rewriter.replaceOp(op, res);
+    return matchSuccess();
+  }
+};
+
 class VectorOuterProductOpConversion : public LLVMOpLowering {
 public:
   explicit VectorOuterProductOpConversion(MLIRContext *context,
@@ -725,49 +895,10 @@
   }
 };
 
-// TODO(rriddle): Better support for attribute subtype forwarding + slicing.
-static SmallVector getI64SubArray(ArrayAttr arrayAttr,
-                                              unsigned dropFront = 0,
-                                              unsigned dropBack = 0) {
-  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
-  auto range = arrayAttr.getAsRange();
-  SmallVector res;
-  res.reserve(arrayAttr.size() - dropFront - dropBack);
-  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
-       it != eit; ++it)
-    res.push_back((*it).getValue().getSExtValue());
-  return res;
-}
-
-/// Emit the proper `ExtractOp` or `ExtractElementOp` depending on the rank
-/// of `vector`.
-static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
-                        int64_t offset) {
-  auto vectorType = vector.getType().cast();
-  if (vectorType.getRank() > 1)
-    return rewriter.create(loc, vector, offset);
-  return rewriter.create(
-      loc, vectorType.getElementType(), vector,
-      rewriter.create(loc, offset));
-}
-
-/// Emit the proper `InsertOp` or `InsertElementOp` depending on the rank
-/// of `vector`.
-static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
-                       Value into, int64_t offset) {
-  auto vectorType = into.getType().cast();
-  if (vectorType.getRank() > 1)
-    return rewriter.create(loc, from, into, offset);
-  return rewriter.create(
-      loc, vectorType, from, into,
-      rewriter.create(loc, offset));
-}
-
 /// Progressive lowering of StridedSliceOp to either:
 ///   1. extractelement + insertelement for the 1-D case
 ///   2. extract + optional strided_slice + insert for the n-D case.
-class VectorStridedSliceOpRewritePattern
-    : public OpRewritePattern {
+class VectorStridedSliceOpConversion : public OpRewritePattern {
 public:
   using OpRewritePattern::OpRewritePattern;
 
@@ -821,7 +952,9 @@
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.insert(ctx);
+  patterns.insert(ctx);
   patterns.insert, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
 // CHECK-LABEL: llvm.func @strided_slice(
-
   %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
 // CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
 // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
@@ -483,4 +482,45 @@
 
   return
 }
 
+func @insert_strided_slice(%a: vector<2x2xf32>, %b: vector<4x4xf32>, %c: vector<4x4x4xf32>) {
+// CHECK-LABEL: @insert_strided_slice
+// %0: different-rank case; the 4x4 source fills the entire [2] slice.
+  %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
+// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
+
+  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// %1: same-rank case, lowered one major-dimension slice at a time.
+// Subvector vector<2xf32> @0 into vector<4xf32> @2
+// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <2 x float>]">
+// CHECK-NEXT: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <4 x float>]">
+// Element @0 -> element @2
+// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK-NEXT: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// Element @1 -> element @3
+// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK-NEXT: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <4 x float>]">
+//
+// Subvector vector<2xf32> @1 into vector<4xf32> @3
+// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <2 x float>]">
+// CHECK-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
+// Element @0 -> element @2
+// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK-NEXT: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// Element @1 -> element @3
+// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK-NEXT: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
+
+  return
+}