diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -92,6 +92,11 @@
 getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                 ArrayRef<OpFoldResult> valueOrAttrVec);
 
+/// Converts a scalar value `operand` to type `toType`. If the value doesn't
+/// convert, a nullptr is returned with a warning.
+Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                           Type toType);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -80,6 +80,50 @@
   return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
 }
 
+Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                                 Type toType) {
+  if (operand.getType() == toType)
+    return operand;
+  if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+    // If operand is floating point, cast directly to the int type.
+    if (operand.getType().isa<FloatType>()) {
+      if (toIntType.isUnsigned())
+        return b.create<arith::FPToUIOp>(loc, toType, operand);
+      return b.create<arith::FPToSIOp>(loc, toType, operand);
+    }
+    // Cast index operands directly to the int type.
+    if (operand.getType().isIndex())
+      return b.create<arith::IndexCastOp>(loc, toType, operand);
+    if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
+      // Either extend or truncate.
+      if (toIntType.getWidth() > fromIntType.getWidth()) {
+        if (toIntType.isUnsigned())
+          return b.create<arith::ExtUIOp>(loc, toType, operand);
+        return b.create<arith::ExtSIOp>(loc, toType, operand);
+      }
+      if (toIntType.getWidth() < fromIntType.getWidth())
+        return b.create<arith::TruncIOp>(loc, toType, operand);
+    }
+  } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+    // If operand is integer, cast directly to the float type.
+    // Note that it is unclear how to cast from BF16<->FP16.
+    if (auto intType = operand.getType().dyn_cast<IntegerType>()) {
+      if (intType.isUnsigned())
+        return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
+      return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
+    }
+    if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
+      if (toFloatType.getWidth() > fromFloatType.getWidth())
+        return b.create<arith::ExtFOp>(loc, toFloatType, operand);
+      if (toFloatType.getWidth() < fromFloatType.getWidth())
+        return b.create<arith::TruncFOp>(loc, toFloatType, operand);
+    }
+  }
+  emitWarning(loc) << "could not cast operand of type " << operand.getType()
+                   << " to " << toType;
+  return nullptr;
+}
+
 SmallVector<Value>
 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1744,8 +1744,15 @@
       if (!fillOp)
         continue;
       fillFound = true;
+      Value fillVal = fillOp.value();
+      auto resultType =
+          fillOp.result().getType().cast<RankedTensorType>().getElementType();
+      Value convertedVal =
+          convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType);
+      if (!convertedVal)
+        return failure();
       payload.getArgument(opOperand->getOperandNumber())
-          .replaceAllUsesWith(fillOp.value());
+          .replaceAllUsesWith(convertedVal);
     }
     return success(fillFound);
   }
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1017,6 +1017,30 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_fill_generic_different_dtype
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
+//   CHECK-NOT: linalg.fill
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
+//  CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
+#map0 = affine_map<(d0) -> (d0)>
+func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
+  %1 = tensor.empty(%0) : tensor<?xf16>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
+  %3 = tensor.empty(%0) : tensor<?xf16>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
+  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
+    %5 = arith.addf %arg1, %arg2 : f16
+    linalg.yield %5 : f16
+  } -> tensor<?xf16>
+  return %4 : tensor<?xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_mixedaccess
 // CHECK-NOT: linalg.fill
 // CHECK: %[[GENERIC_OP:.*]] = linalg.generic