Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -39,6 +39,13 @@
     their associated pointee type as consistently as possible.
   }];
   let constructor = "::mlir::LLVM::createTypeConsistencyPass()";
+
+  let options = [
+    Option<"maxVectorSplitSize", "max-vector-split-size", "unsigned",
+           /*default=*/"512",
+           "Maximum size in bits of a vector value in a load or store operation"
+           " operating on multiple elements that should still be split">,
+  ];
 }
 
 def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> {
Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -53,13 +53,16 @@
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Splits stores of integers which write into multiple adjacent stores
-/// of a pointer. The integer is then split and stores are generated for
-/// every field being stored in a type-consistent manner.
-/// This is currently done on a best-effort basis.
-class SplitIntegerStores : public OpRewritePattern<StoreOp> {
+/// Splits stores which write into multiple adjacent elements of an aggregate
+/// through a pointer. Currently, integers and vectors are split and stores
+/// are generated for every element being stored to in a type-consistent manner.
+/// This is done on a best-effort basis.
+class SplitStores : public OpRewritePattern<StoreOp> {
+  unsigned maxVectorSplitSize;
+
 public:
-  using OpRewritePattern::OpRewritePattern;
+  SplitStores(MLIRContext *context, unsigned maxVectorSplitSize)
+      : OpRewritePattern(context), maxVectorSplitSize(maxVectorSplitSize) {}
 
   LogicalResult matchAndRewrite(StoreOp store,
                                 PatternRewriter &rewrite) const override;
Index: mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -411,12 +411,78 @@
   return body.take_front(exclusiveEnd);
 }
 
-LogicalResult
-SplitIntegerStores::matchAndRewrite(StoreOp store,
-                                    PatternRewriter &rewriter) const {
-  IntegerType sourceType = dyn_cast<IntegerType>(store.getValue().getType());
-  if (!sourceType) {
-    // We currently only support integer sources.
+/// Splits a store of the vector `value` into `address` at `storeOffset` into
+/// multiple stores of each element with the goal of each generated store
+/// becoming type-consistent through subsequent pattern applications.
+static void splitVectorStore(const DataLayout &dataLayout, Location loc,
+                             RewriterBase &rewriter, Value address,
+                             TypedValue<VectorType> value,
+                             unsigned storeOffset) {
+  VectorType vectorType = value.getType();
+  unsigned elementSize = dataLayout.getTypeSize(vectorType.getElementType());
+
+  // Extract every element in the vector and store it in the given address.
+  for (size_t index : llvm::seq<size_t>(0, vectorType.getNumElements())) {
+    auto pos =
+        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(index));
+    auto extractOp = rewriter.create<ExtractElementOp>(loc, value, pos);
+
+    // For convenience, we do indexing by calculating the final byte offset.
+    // Other patterns will turn this into a type-consistent GEP.
+    auto gepOp = rewriter.create<GEPOp>(
+        loc, address.getType(), rewriter.getI8Type(), address,
+        ArrayRef<GEPArg>{storeOffset + index * elementSize});
+
+    rewriter.create<StoreOp>(loc, extractOp, gepOp);
+  }
+}
+
+/// Splits a store of the integer `value` into `address` at `storeOffset` into
+/// multiple stores to each 'writtenFields', making each store operation
+/// type-consistent.
+static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
+                              RewriterBase &rewriter, Value address,
+                              Value value, unsigned storeOffset,
+                              ArrayRef<Type> writtenToFields) {
+  unsigned currentOffset = storeOffset;
+  for (Type type : writtenToFields) {
+    unsigned fieldSize = dataLayout.getTypeSize(type);
+
+    // Extract the data out of the integer by first shifting right and then
+    // truncating it.
+    auto pos = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(value.getType(),
+                                     (currentOffset - storeOffset) * 8));
+
+    auto shrOp = rewriter.create<LShrOp>(loc, value, pos);
+
+    IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
+    Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
+    if (fieldIntType != type) {
+      // Bitcast to the right type. `fieldIntType` was explicitly created
+      // to be of the same size as `type` and must currently be a primitive as
+      // well.
+      valueToStore = rewriter.create<BitcastOp>(loc, type, valueToStore);
+    }
+
+    // We create an `i8` indexed GEP here as that is the easiest (offset is
+    // already known). Other patterns turn this into a type-consistent GEP.
+    auto gepOp =
+        rewriter.create<GEPOp>(loc, address.getType(), rewriter.getI8Type(),
+                               address, ArrayRef<GEPArg>{currentOffset});
+    rewriter.create<StoreOp>(loc, valueToStore, gepOp);
+
+    // No need to care about padding here since we already checked previously
+    // that no padding exists in this range.
+    currentOffset += fieldSize;
+  }
+}
+
+LogicalResult SplitStores::matchAndRewrite(StoreOp store,
+                                           PatternRewriter &rewriter) const {
+  Type sourceType = store.getValue().getType();
+  if (!isa<IntegerType, VectorType>(sourceType)) {
+    // We currently only support integer and vector sources.
     return failure();
   }
 
@@ -465,43 +531,30 @@
   if (failed(writtenToFields))
     return failure();
 
-  unsigned currentOffset = offset;
-  for (Type type : *writtenToFields) {
-    unsigned fieldSize = dataLayout.getTypeSize(type);
-
-    // Extract the data out of the integer by first shifting right and then
-    // truncating it.
-    auto pos = rewriter.create<ConstantOp>(
-        store.getLoc(),
-        rewriter.getIntegerAttr(sourceType, (currentOffset - offset) * 8));
-
-    auto shrOp = rewriter.create<LShrOp>(store.getLoc(), store.getValue(), pos);
-
-    IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
-    Value valueToStore =
-        rewriter.create<TruncOp>(store.getLoc(), fieldIntType, shrOp);
-    if (fieldIntType != type) {
-      // Bitcast to the right type. `fieldIntType` was explicitly created
-      // to be of the same size as `type` and must currently be a primitive as
-      // well.
-      valueToStore =
-          rewriter.create<BitcastOp>(store.getLoc(), type, valueToStore);
-    }
-
-    // We create an `i8` indexed GEP here as that is the easiest (offset is
-    // already known). Other patterns turn this into a type-consistent GEP.
-    auto gepOp = rewriter.create<GEPOp>(store.getLoc(), address.getType(),
-                                        rewriter.getI8Type(), address,
-                                        ArrayRef<GEPArg>{currentOffset});
-    rewriter.create<StoreOp>(store.getLoc(), valueToStore, gepOp);
+  if (writtenToFields->size() <= 1) {
+    // Other patterns should take care of this case, we are only interested in
+    // splitting field stores.
+    return failure();
+  }
 
-    // No need to care about padding here since we already checked previously
-    // that no padding exists in this range.
-    currentOffset += fieldSize;
+  if (isa<IntegerType>(sourceType)) {
+    splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
+                      store.getValue(), offset, *writtenToFields);
+    rewriter.eraseOp(store);
+    return success();
   }
 
-  rewriter.eraseOp(store);
+  // Add a reasonable bound to not split very large vectors that would end up
+  // generating lots of code.
+  if (dataLayout.getTypeSizeInBits(sourceType) > maxVectorSplitSize)
+    return failure();
+
+  // Vector types are simply split into its elements and new stores generated
+  // with those. Subsequent pattern applications will split these stores further
+  // if required.
+  splitVectorStore(dataLayout, store.getLoc(), rewriter, address,
+                   cast<TypedValue<VectorType>>(store.getValue()), offset);
+  rewriter.eraseOp(store);
   return success();
 }
 
@@ -518,7 +571,7 @@
   rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>,
                       AddFieldGetterToStructDirectUse<StoreOp>>(
       &getContext());
   rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
-  rewritePatterns.add<SplitIntegerStores>(&getContext());
+  rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
 
   FrozenRewritePatternSet frozen(std::move(rewritePatterns));
   if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
Index: mlir/test/Dialect/LLVMIR/type-consistency.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -300,3 +300,121 @@
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_write_split
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_write_split(%arg: vector<4xi32>) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> :
(i32) -> !llvm.ptr + + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)> + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @vector_write_split_offset +// CHECK-SAME: %[[ARG:.*]]: vector<4xi32> +llvm.func @vector_write_split_offset(%arg: vector<4xi32>) { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32 + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32, i32, i32)> : (i32) -> !llvm.ptr + %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> 
!llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32] : vector<4xi32> + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32, i32, i32)> + // CHECK: llvm.store %[[EXTRACT]], %[[GEP]] : i32, !llvm.ptr + + llvm.store %arg, %2 : vector<4xi32>, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} + +// ----- + +// Small test that a split vector store will be further optimized (to than e.g. 
+// split integer loads to structs as shown here) + +// CHECK-LABEL: llvm.func @vector_write_split_struct +// CHECK-SAME: %[[ARG:.*]]: vector<2xi64> +llvm.func @vector_write_split_struct(%arg: vector<2xi64>) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32, i32)> : (i32) -> !llvm.ptr + + // CHECK-COUNT-4: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr + + llvm.store %arg, %1 : vector<2xi64>, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @type_consistent_vector_store +// CHECK-SAME: %[[ARG:.*]]: vector<4xi32> +llvm.func @type_consistent_vector_store(%arg: vector<4xi32>) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xi32>)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xi32>)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xi32>)> + // CHECK: llvm.store %[[ARG]], %[[GEP]] + llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @type_consistent_vector_store_other_type +// CHECK-SAME: %[[ARG:.*]]: vector<4xi32> +llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32> + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)> + // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]] + llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], 
%[[ALLOCA]] + llvm.return +}