diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1727,6 +1727,51 @@
   }
 };
 
+// Casts the Value `operand` to type `toType`.
+static Value cast(PatternRewriter &builder, Type toType, Value operand) {
+  auto loc = operand.getLoc();
+  if (operand.getType() == toType)
+    return operand;
+  if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+    // If operand is floating point, cast directly to the int type.
+    if (operand.getType().isa<FloatType>()) {
+      if (toIntType.isUnsigned())
+        return builder.create<arith::FPToUIOp>(loc, toType, operand);
+      return builder.create<arith::FPToSIOp>(loc, toType, operand);
+    }
+    // Cast index operands directly to the int type.
+    if (operand.getType().isIndex())
+      return builder.create<arith::IndexCastOp>(loc, toType, operand);
+    if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
+      // Either extend or truncate.
+      if (toIntType.getWidth() > fromIntType.getWidth()) {
+        if (toIntType.isUnsigned())
+          return builder.create<arith::ExtUIOp>(loc, toType, operand);
+        return builder.create<arith::ExtSIOp>(loc, toType, operand);
+      }
+      if (toIntType.getWidth() < fromIntType.getWidth())
+        return builder.create<arith::TruncIOp>(loc, toType, operand);
+    }
+  } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+    // If operand is integer, cast directly to the float type.
+    // Note that it is unclear how to cast from BF16<->FP16.
+    if (auto intType = operand.getType().dyn_cast<IntegerType>()) {
+      if (intType.isUnsigned())
+        return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
+      return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
+    }
+    if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
+      if (toFloatType.getWidth() > fromFloatType.getWidth())
+        return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
+      if (toFloatType.getWidth() < fromFloatType.getWidth())
+        return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
+    }
+  }
+  emitWarning(operand.getLoc()) << "could not cast operand of type "
+                                << operand.getType() << " to " << toType;
+  return nullptr;
+}
+
 /// Fold linalg.fill into linalg.generic
 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
@@ -1744,8 +1789,14 @@
       if (!fillOp)
         continue;
       fillFound = true;
+      Value fillVal = fillOp.value();
+      auto resultType =
+          fillOp.result().getType().cast<RankedTensorType>().getElementType();
+      Value castedVal = cast(rewriter, resultType, fillVal);
+      if (!castedVal)
+        return failure();
       payload.getArgument(opOperand->getOperandNumber())
-          .replaceAllUsesWith(fillOp.value());
+          .replaceAllUsesWith(castedVal);
     }
     return success(fillFound);
   }
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1017,6 +1017,30 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_fill_generic_different_dtype
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
+#map0 = affine_map<(d0) -> (d0)>
+func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
+  %1 = tensor.empty(%0) : tensor<?xf16>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
+  %3 = tensor.empty(%0) : tensor<?xf16>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
+  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
+    %5 = arith.addf %arg1, %arg2 : f16
+    linalg.yield %5 : f16
+  } -> tensor<?xf16>
+  return %4 : tensor<?xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_mixedaccess
 // CHECK-NOT: linalg.fill
 // CHECK: %[[GENERIC_OP:.*]] = linalg.generic