Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td +++ mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td @@ -39,6 +39,13 @@ their associated pointee type as consistently as possible. }]; let constructor = "::mlir::LLVM::createTypeConsistencyPass()"; + + let options = [ + Option<"maxVectorStoreSplitSize", "max-vector-store-split-size", "unsigned", + /*default=*/"512", + "Maximum size in bits of a vector value in a store operation storing" + " to multiple elements that should still be split">, + ]; } def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> { Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h +++ mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h @@ -53,13 +53,18 @@ PatternRewriter &rewriter) const override; }; -/// Splits stores of integers which write into multiple adjacent stores -/// of a pointer. The integer is then split and stores are generated for -/// every field being stored in a type-consistent manner. -/// This is currently done on a best-effort basis. -class SplitIntegerStores : public OpRewritePattern<StoreOp> { +/// Splits stores which write into multiple adjacent elements of an aggregate +/// through a pointer. Currently, integers and vectors are split and stores +/// are generated for every element being stored to in a type-consistent manner. +/// This is done on a best-effort basis.
+class SplitStores : public OpRewritePattern<StoreOp> { + /// Vectors whose size in bits exceeds this value are not split. + unsigned maxVectorStoreSplitSize; + public: - using OpRewritePattern::OpRewritePattern; + SplitStores(MLIRContext *context, unsigned maxVectorStoreSplitSize) + : OpRewritePattern(context), + maxVectorStoreSplitSize(maxVectorStoreSplitSize) {} LogicalResult matchAndRewrite(StoreOp store, PatternRewriter &rewrite) const override; Index: mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp +++ mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -411,12 +411,11 @@ return body.take_front(exclusiveEnd); } -LogicalResult -SplitIntegerStores::matchAndRewrite(StoreOp store, - PatternRewriter &rewriter) const { - IntegerType sourceType = dyn_cast<IntegerType>(store.getValue().getType()); - if (!sourceType) { - // We currently only support integer sources. +LogicalResult SplitStores::matchAndRewrite(StoreOp store, + PatternRewriter &rewriter) const { + Type sourceType = store.getValue().getType(); + if (!isa<IntegerType, VectorType>(sourceType)) { + // We currently only support integer and vector sources. return failure(); } @@ -465,6 +464,54 @@ if (failed(writtenToFields)) return failure(); + + if (writtenToFields->size() <= 1) { + // Other patterns should take care of this case, we are only interested in + // splitting field stores. + return failure(); + } + + // Vector types are simply split into their elements and new stores generated + // with those. + // Subsequent pattern applications will split these stores further if + // required. + if (auto vectorType = dyn_cast<VectorType>(sourceType)) { + + // Add a reasonable bound to not split very large vectors that would end up + // generating lots of code. + if (dataLayout.getTypeSizeInBits(sourceType) > maxVectorStoreSplitSize) + return failure(); + + // Elements are addressed by byte offset below, which is only meaningful + // if every element occupies a whole number of bytes. Vectors of e.g. i1 + // are bit-packed in memory and must not be split this way. + Type elementType = vectorType.getElementType(); + if (dataLayout.getTypeSizeInBits(elementType) % 8 != 0) + return failure(); + unsigned elementSize = dataLayout.getTypeSize(elementType); + + // Extract every element in the vector and store it in the given address.
+ for (size_t index : llvm::seq<size_t>(0, vectorType.getNumElements())) { + auto pos = rewriter.create<ConstantOp>(store.getLoc(), + rewriter.getI32IntegerAttr(index)); + auto extractOp = rewriter.create<ExtractElementOp>(store.getLoc(), + store.getValue(), pos); + + // The store begins `offset` bytes past `address` (just like the integer + // path below), so elements are addressed by byte offset from there. + // Other patterns will turn this into a type-consistent GEP. + auto gepOp = rewriter.create<GEPOp>( + store.getLoc(), address.getType(), rewriter.getI8Type(), address, + ArrayRef<GEPArg>{ + static_cast<int32_t>(offset + index * elementSize)}); + + rewriter.create<StoreOp>(store.getLoc(), extractOp, gepOp); + } + + rewriter.eraseOp(store); + + return success(); + } + unsigned currentOffset = offset; for (Type type : *writtenToFields) { unsigned fieldSize = dataLayout.getTypeSize(type); @@ -518,7 +565,7 @@ rewritePatterns.add>( &getContext()); rewritePatterns.add(&getContext()); - rewritePatterns.add<SplitIntegerStores>(&getContext()); + rewritePatterns.add<SplitStores>(&getContext(), maxVectorStoreSplitSize); FrozenRewritePatternSet frozen(std::move(rewritePatterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen))) Index: mlir/test/Dialect/LLVMIR/type-consistency.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/type-consistency.mlir +++ mlir/test/Dialect/LLVMIR/type-consistency.mlir @@ -300,3 +300,103 @@ // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] llvm.return } + + // ----- + + // CHECK-LABEL: llvm.func @vector_write_split + // CHECK-SAME: %[[ARG:.*]]: vector<4xi32> +llvm.func @vector_write_split(%arg: vector<4xi32>) { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32 + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> : (i32) -> !llvm.ptr + + // CHECK: 
%[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)> + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} + + // ----- + + // CHECK-LABEL: llvm.func @type_consistent_vector_store + // CHECK-SAME: %[[ARG:.*]]: vector<4xi32> +llvm.func @type_consistent_vector_store(%arg: vector<4xi32>) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xi32>)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xi32>)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xi32>)> + // CHECK: llvm.store %[[ARG]], %[[GEP]] + llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr + llvm.return +} + + // ----- + + // CHECK-LABEL: llvm.func @type_consistent_vector_store_other_type + // CHECK-SAME: %[[ARG:.*]]: 
vector<4xi32> +llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32> + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)> + // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]] + llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} + + // ----- + + // CHECK-LABEL: llvm.func @vector_write_split_offset + // CHECK-SAME: %[[ARG:.*]]: vector<4xi32> +llvm.func @vector_write_split_offset(%arg: vector<4xi32>) { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32 + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32, i32, i32)> : (i32) -> !llvm.ptr + // The store begins at field 1, so the split stores must target fields 1 to 4. + %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, 
!llvm.struct<"foo", (i64, i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + llvm.store %arg, %2 : vector<4xi32>, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]] + llvm.return +}