NOTE(review): Patch for the MLIR LLVM-dialect type-consistency pass
(TypeConsistency.h, TypeConsistency.cpp, type-consistency.mlir).
What the hunks below do:
  - Rename the helper areCastCompatible -> areBitcastCompatible and update its
    call sites in the load and store struct-field indirection patterns.
  - Remove bitcast creation from the store struct-field indirection pattern.
    Also remove it from the store-splitting code in the hunk at old line 458
    (presumably SplitStores). Instead, add a new BitcastStores pattern.
  - BitcastStores fires when the stored value's type disagrees with the type
    hint for the store address (isElementTypeInconsistent) and
    areBitcastCompatible accepts the pair. It then creates a bitcast of the
    value to the hinted type and makes the store use that result.
  - Register BitcastStores in the pass's rewrite pattern set.
  - Update FileCheck expectations, since the bitcast is now checked after the
    GEP. Add a @bitcast_insertion test that stores an i32 into an f32 alloca.

NOTE(review): As captured here this patch cannot be applied or compiled:
  - Line breaks have been collapsed, so the hunks no longer match their
    @@ headers.
  - Template argument lists appear to have been stripped. Examples:
      - `OpRewritePattern`, presumably `OpRewritePattern<StoreOp>`;
      - `rewriter.create(...)`, presumably `rewriter.create<BitcastOp>(...)`,
        or `rewriter.create<TruncOp>(...)` for the trunc context line;
      - `isa(lhs)` / `isa(rhs)`;
      - `rewritePatterns.add(&getContext())`, where the added line presumably
        reads `rewritePatterns.add<BitcastStores>(&getContext())`.
    Regenerate the patch from the source tree rather than hand-repairing it.
    Context lines lost their exact content and cannot be reconstructed
    reliably from this copy.

NOTE(review): Points to confirm against the full source:
  - After this change the store indirection pattern runs
    `store.getValueMutable().assign(store.getValue())`, which assigns the value
    operand to itself. If insertFieldIndirection already notifies the rewriter
    of the address update, this updateRootInPlace call looks redundant.
  - BitcastStores only checks areBitcastCompatible. That helper accepts types
    that are equal, or types with equal DataLayout size that are not excluded
    by the (stripped) isa<...> list. Confirm that list rules out
    pointer <-> non-pointer pairs of equal size. Otherwise an invalid
    llvm.bitcast (e.g. i64 <-> !llvm.ptr) could be created.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h @@ -68,6 +68,17 @@ PatternRewriter &rewrite) const override; }; +/// Transforms type-inconsistent stores, aka stores where the type hint of +/// the address contradicts the value stored, by inserting a bitcast if +/// possible. +class BitcastStores : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(StoreOp store, + PatternRewriter &rewriter) const override; +}; + } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp --- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -43,7 +43,7 @@ } /// Checks that two types are the same or can be bitcast into one another. -static bool areCastCompatible(DataLayout &layout, Type lhs, Type rhs) { +static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) { return lhs == rhs || (!isa(lhs) && !isa(rhs) && layout.getTypeSize(lhs) == layout.getTypeSize(rhs)); @@ -104,7 +104,7 @@ if (!firstType) return failure(); DataLayout layout = DataLayout::closest(load); - if (!areCastCompatible(layout, firstType, load.getResult().getType())) + if (!areBitcastCompatible(layout, firstType, load.getResult().getType())) return failure(); insertFieldIndirection(load, rewriter, inconsistentElementType); @@ -144,20 +144,13 @@ DataLayout layout = DataLayout::closest(store); // Check that the first field has the right type or can at least be bitcast // to the right type.
- if (!areCastCompatible(layout, firstType, store.getValue().getType())) + if (!areBitcastCompatible(layout, firstType, store.getValue().getType())) return failure(); insertFieldIndirection(store, rewriter, inconsistentElementType); - Value replaceValue = store.getValue(); - if (firstType != store.getValue().getType()) { - rewriter.setInsertionPointAfterValue(store.getValue()); - replaceValue = rewriter.create(store->getLoc(), firstType, - store.getValue()); - } - rewriter.updateRootInPlace( - store, [&]() { store.getValueMutable().assign(replaceValue); }); + store, [&]() { store.getValueMutable().assign(store.getValue()); }); return success(); } @@ -458,12 +451,6 @@ IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8); Value valueToStore = rewriter.create(loc, fieldIntType, shrOp); - if (fieldIntType != type) { - // Bitcast to the right type. `fieldIntType` was explicitly created - // to be of the same size as `type` and must currently be a primitive as - // well. - valueToStore = rewriter.create(loc, type, valueToStore); - } // We create an `i8` indexed GEP here as that is the easiest (offset is // already known). Other patterns turn this into a type-consistent GEP. @@ -558,6 +545,26 @@ return success(); } +LogicalResult BitcastStores::matchAndRewrite(StoreOp store, + PatternRewriter &rewriter) const { + Type sourceType = store.getValue().getType(); + Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType); + if (!typeHint) { + // Nothing to do, since it is already consistent.
+ return failure(); + } + + auto dataLayout = DataLayout::closest(store); + if (!areBitcastCompatible(dataLayout, typeHint, sourceType)) + return failure(); + + auto bitcastOp = + rewriter.create(store.getLoc(), typeHint, store.getValue()); + rewriter.updateRootInPlace( + store, [&] { store.getValueMutable().assign(bitcastOp); }); + return success(); +} + //===----------------------------------------------------------------------===// // Type consistency pass //===----------------------------------------------------------------------===// @@ -572,6 +579,7 @@ &getContext()); rewritePatterns.add(&getContext()); rewritePatterns.add(&getContext(), maxVectorSplitSize); + rewritePatterns.add(&getContext()); FrozenRewritePatternSet frozen(std::move(rewritePatterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen))) diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir --- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir +++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir @@ -218,8 +218,8 @@ // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]] // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64 // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32 - // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32 // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)> + // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32 // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]] llvm.store %arg, %1 : i64, !llvm.ptr // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] @@ -409,12 +409,27 @@ // CHECK-SAME: %[[ARG:.*]]: vector<4xi32> llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) { %0 = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32> // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)> %1 = llvm.alloca %0 x
!llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)> + // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32> // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]] llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] llvm.return } + +// ----- + +// CHECK-LABEL: llvm.func @bitcast_insertion +// CHECK-SAME: %[[ARG:.*]]: i32 +llvm.func @bitcast_insertion(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x f32 + %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr + // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : i32 to f32 + // CHECK: llvm.store %[[BIT_CAST]], %[[ALLOCA]] + llvm.store %arg, %1 : i32, !llvm.ptr + // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]] + llvm.return +}