diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -22,6 +22,8 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
@@ -897,6 +899,62 @@
   }
 };
 
+// Consume tuple elements that were extracted by tuple_get operations.
+// Note that this is a conversion pattern, since we must also rewrite
+// tuples introduced by other rules in this lowering.
+//
+// TODO(ajcbik): express this through general folding?
+//
+class VectorTupleGetOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorTupleGetOpConversion(MLIRContext *context,
+                                      LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::TupleGetOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tupleGetOp = cast<vector::TupleGetOp>(op);
+    // Rewrite:
+    //   %t = vector.tuple .., %e_i, ..
+    //   %x = vector.tuple_get %t, i
+    // into:
+    //   %t = vector.tuple .., %e_i, ..  // one less use
+    //   %x = %e_i
+    if (auto tupleOp =
+            dyn_cast_or_null<vector::TupleOp>(operands[0].getDefiningOp())) {
+      rewriter.replaceOp(tupleGetOp,
+                         tupleOp.getOperands()[tupleGetOp.getIndex()]);
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
+// The removal of a vector tuple fails until all canonicalizations that
+// remove its uses have occurred (viz. all individual elements have been
+// "consumed" somewhere), at which point the pattern applies.
+//
+// NOTE: this is not foolproof, since newly introduced vector tuples have
+// an empty use list; such tuples need to be legalized somehow anyway,
+// and unconditionally erasing them is the right way (even though this may
+// assert at the final rewriting if use-removal rules are missing).
+//
+// TODO(ajcbik): avoid all this by simply relying on DCE
+//
+class VectorTupleOpConversion : public OpRewritePattern<TupleOp> {
+public:
+  using OpRewritePattern<TupleOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TupleOp op,
+                                     PatternRewriter &rewriter) const override {
+    if (!op.use_empty())
+      return matchFailure();
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
 /// Progressive lowering of StridedSliceOp to either:
 ///   1. extractelement + insertelement for the 1-D case
 ///   2. extract + optional strided_slice + insert for the n-D case.
@@ -950,20 +1008,91 @@
   }
 };
 
+/// Progressive lowering of ExtractSlicesOp to tuple of StridedSliceOp.
+/// One:
+///   %x = vector.extract_slices
+/// is replaced by:
+///   %a = vector.strided_slice
+///   %b = vector.strided_slice
+///   ..
+///   %x = vector.tuple %a, %b, ..
+class VectorExtractSlicesOpConversion
+    : public OpRewritePattern<ExtractSlicesOp> {
+public:
+  using OpRewritePattern<ExtractSlicesOp>::OpRewritePattern;
+
+  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+  PatternMatchResult matchAndRewrite(ExtractSlicesOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType vectorType = op.getSourceVectorType();
+    int64_t rank = vectorType.getRank();
+    auto shape = vectorType.getShape();
+
+    SmallVector<int64_t, 4> sizes;
+    op.getSizes(sizes);
+    SmallVector<int64_t, 4> strides;
+    op.getStrides(strides); // all-ones at the moment
+
+    // Compute the number of slices in each dimension.
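+    // For example (shapes borrowed from the test added below, purely as an
+    // illustration): a vector<3x3xf32> source sliced with sizes [2, 2]
+    // yields dimSliceCounts = [ceilDiv(3, 2), ceilDiv(3, 2)] = [2, 2],
+    // i.e. a 2x2 arrangement of slices.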
+    SmallVector<int64_t, 4> dimSliceCounts(rank);
+    for (int64_t r = 0; r < rank; ++r)
+      dimSliceCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+    // Compute the strides between slices in each dimension.
+    SmallVector<int64_t, 4> sliceStrides(rank);
+    sliceStrides[rank - 1] = 1;
+    for (int64_t r = rank - 2; r >= 0; --r)
+      sliceStrides[r] = sliceStrides[r + 1] * dimSliceCounts[r + 1];
+
+    // For each element in the tuple, generate the proper strided slice.
+    TupleType tupleType = op.getResultTupleType();
+    int64_t tupleSize = tupleType.size();
+    SmallVector<Value, 4> tupleValues(tupleSize);
+    for (int64_t i = 0; i < tupleSize; ++i) {
+      // Compute the vector offsets by de-linearizing the index.
+      SmallVector<int64_t, 4> vectorOffsets(rank);
+      int64_t linearIndex = i;
+      for (int64_t r = 0; r < rank; ++r) {
+        vectorOffsets[r] = linearIndex / sliceStrides[r];
+        linearIndex %= sliceStrides[r];
+      }
+      // Convert from unrolled vector-space offsets to element-space offsets.
+      auto elementOffsets = mlir::functional::zipMap(
+          [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets,
+          sizes);
+      // Compute the size of each slice.
+      SmallVector<int64_t, 4> sliceSizes(rank);
+      for (int64_t r = 0; r < rank; ++r) {
+        sliceSizes[r] = std::min(sizes[r], shape[r] - elementOffsets[r]);
+      }
+      // Insert in tuple.
+      tupleValues[i] = rewriter.create<StridedSliceOp>(
+          loc, op.vector(), elementOffsets, sliceSizes, strides);
+    }
+
+    Value tuple = rewriter.create<TupleOp>(loc, tupleType, tupleValues);
+    rewriter.replaceOp(op, tuple);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.insert<VectorFMAOpNDRewritePattern,
-                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
-                  VectorInsertStridedSliceOpSameRankRewritePattern,
-                  VectorStridedSliceOpConversion>(ctx);
+  patterns.insert<VectorFMAOpNDRewritePattern,
+                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
+                  VectorInsertStridedSliceOpSameRankRewritePattern,
+                  VectorExtractSlicesOpConversion,
+                  VectorStridedSliceOpConversion, VectorTupleOpConversion>(ctx);
   patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
                   VectorShuffleOpConversion, VectorExtractElementOpConversion,
                   VectorExtractOpConversion, VectorFMAOp1DConversion,
                   VectorInsertElementOpConversion, VectorInsertOpConversion,
                   VectorOuterProductOpConversion, VectorTypeCastOpConversion,
-                  VectorPrintOpConversion>(ctx, converter);
+                  VectorPrintOpConversion, VectorTupleGetOpConversion>(
+      ctx, converter);
 }
 
 namespace {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -424,7 +424,6 @@
 // CHECK: llvm.call @print_close() : () -> ()
 // CHECK: llvm.call @print_newline() : () -> ()
 
-
 func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
 // CHECK-LABEL: llvm.func @strided_slice(
   %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
@@ -524,3 +523,19 @@
   return
 }
 
+func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+    : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  %1 = vector.tuple_get %0, 3 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  return %1 : vector<1x1xf32>
+}
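+// For this input, extract_slices produces four strided slices of the 3x3
+// source (offsets given in element space): [0, 0] -> 2x2, [0, 2] -> 2x1,
+// [2, 0] -> 1x2, and [2, 2] -> 1x1. The tuple_get at index 3 selects the
+// trailing 1x1 slice, so after the tuple folding only that slice is
+// lowered, as the checks below verify.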
+// CHECK-LABEL: extract_strides(%arg0: !llvm<"[3 x <3 x float>]">)
+// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]">
+// CHECK: %[[s1:.*]] = llvm.extractvalue %arg0[2] : !llvm<"[3 x <3 x float>]">
+// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm<"<1 x float>">
+// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm<"<3 x float>">
+// CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>">
+// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]">
+// CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]">
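+
+// For reference (an illustrative sketch, not CHECK'd output): at the vector
+// level, after the tuple folding the body above reduces to the single slice
+//   %1 = vector.strided_slice %arg0
+//     {offsets = [2, 2], sizes = [1, 1], strides = [1, 1]}
+//     : vector<3x3xf32> to vector<1x1xf32>
+// which the strided-slice lowering then expands into the extract/insert
+// chain checked above.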