diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1115,6 +1115,7 @@
     }
     static StringRef getIndexAttrName() { return "index"; }
   }];
+  let hasCanonicalizer = 1;
 }

 def Vector_PrintOp :
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -22,6 +22,8 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
@@ -950,6 +952,153 @@
   }
 };

+/// Progressive lowering of ExtractSlicesOp to a tuple of StridedSliceOp.
+/// One:
+///   %x = vector.extract_slices %0
+/// is replaced by:
+///   %a = vector.strided_slice %0
+///   %b = vector.strided_slice %0
+///   ..
+///   %x = vector.tuple %a, %b, ..
+class VectorExtractSlicesOpConversion
+    : public OpRewritePattern<ExtractSlicesOp> {
+public:
+  using OpRewritePattern<ExtractSlicesOp>::OpRewritePattern;
+
+  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+  PatternMatchResult matchAndRewrite(ExtractSlicesOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType vectorType = op.getSourceVectorType();
+    int64_t rank = vectorType.getRank();
+    auto shape = vectorType.getShape();
+
+    SmallVector<int64_t, 4> sizes;
+    op.getSizes(sizes);
+    SmallVector<int64_t, 4> strides;
+    op.getStrides(strides); // all-ones at the moment
+
+    // Compute the number of slices in each dimension.
+    SmallVector<int64_t, 4> dimSliceCounts(rank);
+    for (int64_t r = 0; r < rank; ++r)
+      dimSliceCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+    // Compute the strides between slices in each dimension.
+    SmallVector<int64_t, 4> sliceStrides(rank);
+    sliceStrides[rank - 1] = 1;
+    for (int64_t r = rank - 2; r >= 0; --r)
+      sliceStrides[r] = sliceStrides[r + 1] * dimSliceCounts[r + 1];
+
+    // For each element in the tuple, generate the proper strided slice.
+    TupleType tupleType = op.getResultTupleType();
+    int64_t tupleSize = tupleType.size();
+    SmallVector<Value, 4> tupleValues(tupleSize);
+    for (int64_t i = 0; i < tupleSize; ++i) {
+      // Compute the vector offsets by de-linearizing the index.
+      SmallVector<int64_t, 4> vectorOffsets(rank);
+      int64_t linearIndex = i;
+      for (int64_t r = 0; r < rank; ++r) {
+        vectorOffsets[r] = linearIndex / sliceStrides[r];
+        linearIndex %= sliceStrides[r];
+      }
+      // Convert from unrolled vector-space offsets to element-space offsets.
+      auto elementOffsets = mlir::functional::zipMap(
+          [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets,
+          sizes);
+      // Compute the size of each slice.
+      SmallVector<int64_t, 4> sliceSizes(rank);
+      for (int64_t r = 0; r < rank; ++r) {
+        sliceSizes[r] = std::min(sizes[r], shape[r] - elementOffsets[r]);
+      }
+      // Insert in tuple.
+      tupleValues[i] = rewriter.create<StridedSliceOp>(
+          loc, op.vector(), elementOffsets, sliceSizes, strides);
+    }
+
+    rewriter.replaceOpWithNewOp<TupleOp>(op, tupleType, tupleValues);
+    return matchSuccess();
+  }
+};
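+
+// Illustrative example of the index arithmetic above, for a hypothetical
+// vector<4x4xf32> source extracted with sizes = [2, 2] and unit strides:
+// dimSliceCounts = [2, 2] and sliceStrides = [2, 1], so tuple element i = 3
+// de-linearizes to vectorOffsets = [1, 1], which maps to elementOffsets =
+// [2, 2] and sliceSizes = [2, 2].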
+
+/// Progressive lowering of InsertSlicesOp to a series of InsertStridedSliceOp.
+/// One:
+///   %x = vector.insert_slices %0
+/// is replaced by:
+///   %r0 = vector.splat 0
+///   %t1 = vector.tuple_get %0, 0
+///   %r1 = vector.insert_strided_slice %r0, %t1
+///   %t2 = vector.tuple_get %0, 1
+///   %r2 = vector.insert_strided_slice %r1, %t2
+///   ..
+///   %x = ..
+class VectorInsertSlicesOpConversion
+    : public OpRewritePattern<InsertSlicesOp> {
+public:
+  using OpRewritePattern<InsertSlicesOp>::OpRewritePattern;
+
+  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+  PatternMatchResult matchAndRewrite(InsertSlicesOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType vectorType = op.getResultVectorType();
+    int64_t rank = vectorType.getRank();
+    auto shape = vectorType.getShape();
+
+    SmallVector<int64_t, 4> sizes;
+    op.getSizes(sizes);
+    SmallVector<int64_t, 4> strides;
+    op.getStrides(strides); // all-ones at the moment
+
+    // Compute the number of slices in each dimension.
+    SmallVector<int64_t, 4> dimSliceCounts(rank);
+    for (int64_t r = 0; r < rank; ++r)
+      dimSliceCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+    // Compute the strides between slices in each dimension.
+    SmallVector<int64_t, 4> sliceStrides(rank);
+    sliceStrides[rank - 1] = 1;
+    for (int64_t r = rank - 2; r >= 0; --r)
+      sliceStrides[r] = sliceStrides[r + 1] * dimSliceCounts[r + 1];
+
+    // Prepare the result vector (initialized to zero).
+    auto elemType = vectorType.getElementType();
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value result = rewriter.create<SplatOp>(loc, vectorType, zero);
+
+    // For each element in the tuple, extract the proper strided slice.
+    TupleType tupleType = op.getSourceTupleType();
+    int64_t tupleSize = tupleType.size();
+    for (int64_t i = 0; i < tupleSize; ++i) {
+      // Compute the vector offsets by de-linearizing the index.
+      SmallVector<int64_t, 4> vectorOffsets(rank);
+      int64_t linearIndex = i;
+      for (int64_t r = 0; r < rank; ++r) {
+        vectorOffsets[r] = linearIndex / sliceStrides[r];
+        linearIndex %= sliceStrides[r];
+      }
+      // Convert from unrolled vector-space offsets to element-space offsets.
+      auto elementOffsets = mlir::functional::zipMap(
+          [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets,
+          sizes);
+      // Compute the size of each slice.
+      SmallVector<int64_t, 4> sliceSizes(rank);
+      for (int64_t r = 0; r < rank; ++r) {
+        sliceSizes[r] = std::min(sizes[r], shape[r] - elementOffsets[r]);
+      }
+      // Extract from the tuple and insert into the result.
+      auto index = rewriter.getI64IntegerAttr(i);
+      auto tupleGet = rewriter.create<TupleGetOp>(
+          loc, tupleType.getType(i), op.getOperand(), index);
+      result = rewriter.create<InsertStridedSliceOp>(loc, tupleGet, result,
+                                                     elementOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
+    return matchSuccess();
+  }
+};
+
 } // namespace

 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -973,19 +1122,33 @@
 } // namespace

 void LowerVectorToLLVMPass::runOnModule() {
-  // Convert to the LLVM IR dialect using the converter defined above.
-  OwningRewritePatternList patterns;
-  LLVMTypeConverter converter(&getContext());
-  populateVectorToLLVMConversionPatterns(converter, patterns);
-  populateStdToLLVMConversionPatterns(converter, patterns);
-
-  ConversionTarget target(getContext());
-  target.addLegalDialect<LLVM::LLVMDialect>();
-  target.addDynamicallyLegalOp<FuncOp>(
-      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-  if (failed(
-          applyPartialConversion(getModule(), target, patterns, &converter))) {
-    signalPassFailure();
+  // Perform progressive lowering of operations on slices. The
+  // canonicalization rewriting, together with DCE in the greedy
+  // rewriter, ensures that any introduced tuples remain completely
+  // invisible to the next lowering pass.
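+  // For example, after the slice lowering, IR of the form (hypothetical
+  // values %a and %b):
+  //   %t = vector.tuple %a, %b
+  //   %x = vector.tuple_get %t, 1
+  // canonicalizes %x into a direct use of %b, and the then-dead vector.tuple
+  // is removed by DCE.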
+  {
+    OwningRewritePatternList patterns;
+    patterns.insert<VectorExtractSlicesOpConversion,
+                    VectorInsertSlicesOpConversion>(&getContext());
+    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
+    applyPatternsGreedily(getModule(), patterns);
+  }
+
+  // Convert to the LLVM IR dialect.
+  {
+    LLVMTypeConverter converter(&getContext());
+    OwningRewritePatternList patterns;
+    populateVectorToLLVMConversionPatterns(converter, patterns);
+    populateStdToLLVMConversionPatterns(converter, patterns);
+
+    ConversionTarget target(getContext());
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addDynamicallyLegalOp<FuncOp>(
+        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+    if (failed(applyPartialConversion(getModule(), target, patterns,
+                                      &converter))) {
+      signalPassFailure();
+    }
   }
 }
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1681,6 +1681,36 @@
   return success();
 }

+namespace {
+
+class TupleGetFolder : public OpRewritePattern<TupleGetOp> {
+public:
+  using OpRewritePattern<TupleGetOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TupleGetOp op,
+                                     PatternRewriter &rewriter) const override {
+    // Rewrite:
+    //   %t = vector.tuple .., %e_i, ..
+    //   %x = vector.tuple_get %t, i
+    // into:
+    //   %t = vector.tuple .., %e_i, ..  // one less use
+    //   %x = %e_i
+    if (auto tupleOp =
+            dyn_cast_or_null<TupleOp>(op.getOperand().getDefiningOp())) {
+      rewriter.replaceOp(op, tupleOp.getOperand(op.getIndex()));
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
+} // end anonymous namespace
+
+void TupleGetOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  results.insert<TupleGetFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
@@ -1814,7 +1844,9 @@

 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
+  patterns
+      .insert<CreateMaskFolder, StridedSliceConstantMaskFolder, TupleGetFolder>(
+          context);
 }

 namespace mlir {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -424,10 +424,11 @@
 // CHECK: llvm.call @print_close() : () -> ()
 // CHECK: llvm.call @print_newline() : () -> ()

-
-func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
-// CHECK-LABEL: llvm.func @strided_slice(
+func @strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
   %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+  return %0 : vector<2xf32>
+}
+// CHECK-LABEL: llvm.func @strided_slice1
 // CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
 // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
 // CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
@@ -439,7 +440,11 @@
 // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
-  %1 = vector.strided_slice %arg1 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
+func @strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
+  %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
+  return %0 : vector<2x8xf32>
+}
+// CHECK-LABEL: llvm.func @strided_slice2
 // CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
 // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]">
 // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]">
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"[2 x <8 x float>]">
@@ -447,7 +452,11 @@
 // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]">
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]">

-  %2 = vector.strided_slice %arg1 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+func @strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
+  %0 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+// CHECK-LABEL: llvm.func @strided_slice3
 // CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
 // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]">
 //
@@ -479,17 +488,19 @@
 // CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
 // CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]">

-  return
-}
-
-func @insert_strided_slice(%a: vector<2x2xf32>, %b: vector<4x4xf32>, %c: vector<4x4x4xf32>) {
-// CHECK-LABEL: @insert_strided_slice
-
+func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
   %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
+  return %0 : vector<4x4x4xf32>
+}
+// CHECK-LABEL: @insert_strided_slice1
 // CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
 // CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">

-  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+// CHECK-LABEL: @insert_strided_slice2
 //
 // Subvector vector<2xf32> @0 into vector<4xf32> @2
 // CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <2 x float>]">
@@ -521,6 +532,19 @@
 // CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
 // CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">

-  return
+func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+    : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  %1 = vector.tuple_get %0, 3 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  return %1 : vector<1x1xf32>
 }
-
+// CHECK-LABEL: extract_strides(%arg0: !llvm<"[3 x <3 x float>]">)
+// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]">
+// CHECK: %[[s1:.*]] = llvm.extractvalue %arg0[2] : !llvm<"[3 x <3 x float>]">
+// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm<"<1 x float>">
+// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
">
+// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]">
+// CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]">