Factor the scalar cast logic out of the `cast` helper in LinalgOps.cpp into a
shared `mlir::convertScalarToDtype` utility (Arith/Utils), and use it when
folding a `linalg.fill` into a `linalg.generic`: the fill value is converted to
the element type of the fill result before it replaces the payload argument.
Adds a test where an f32 fill value feeds an f16 generic.

diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -92,6 +92,12 @@
 getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                 ArrayRef<OpFoldResult> valueOrAttrVec);
 
+/// Converts a scalar value `operand` to type `toType`. If the value doesn't
+/// convert, a warning will be issued and the operand is returned as is (which
+/// will presumably yield a verification issue downstream).
+Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                           Type toType, bool isUnsignedCast);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -80,6 +80,50 @@
   return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
 }
 
+Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                                 Type toType, bool isUnsignedCast) {
+  if (operand.getType() == toType)
+    return operand;
+  if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+    // If operand is floating point, cast directly to the int type.
+    if (operand.getType().isa<FloatType>()) {
+      if (isUnsignedCast)
+        return b.create<arith::FPToUIOp>(loc, toType, operand);
+      return b.create<arith::FPToSIOp>(loc, toType, operand);
+    }
+    // Cast index operands directly to the int type.
+    if (operand.getType().isIndex())
+      return b.create<arith::IndexCastOp>(loc, toType, operand);
+    if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
+      // Either extend or truncate.
+      if (toIntType.getWidth() > fromIntType.getWidth()) {
+        if (isUnsignedCast)
+          return b.create<arith::ExtUIOp>(loc, toType, operand);
+        return b.create<arith::ExtSIOp>(loc, toType, operand);
+      }
+      if (toIntType.getWidth() < fromIntType.getWidth())
+        return b.create<arith::TruncIOp>(loc, toType, operand);
+    }
+  } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+    // If operand is integer, cast directly to the float type.
+    // Note that it is unclear how to cast from BF16<->FP16.
+    if (operand.getType().isa<IntegerType>()) {
+      if (isUnsignedCast)
+        return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
+      return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
+    }
+    if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
+      if (toFloatType.getWidth() > fromFloatType.getWidth())
+        return b.create<arith::ExtFOp>(loc, toFloatType, operand);
+      if (toFloatType.getWidth() < fromFloatType.getWidth())
+        return b.create<arith::TruncFOp>(loc, toFloatType, operand);
+    }
+  }
+  emitWarning(loc) << "could not cast operand of type " << operand.getType()
+                   << " to " << toType;
+  return operand;
+}
+
 SmallVector<Value>
 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -423,48 +423,7 @@
   Value cast(Type toType, Value operand, bool isUnsignedCast) {
     OpBuilder builder = getBuilder();
     auto loc = operand.getLoc();
-
-    if (operand.getType() == toType)
-      return operand;
-    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
-      // If operand is floating point, cast directly to the int type.
-      if (operand.getType().isa<FloatType>()) {
-        if (isUnsignedCast)
-          return builder.create<arith::FPToUIOp>(loc, toType, operand);
-        return builder.create<arith::FPToSIOp>(loc, toType, operand);
-      }
-      // Cast index operands directly to the int type.
-      if (operand.getType().isIndex())
-        return builder.create<arith::IndexCastOp>(loc, toType, operand);
-      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
-        // Either extend or truncate.
-        if (toIntType.getWidth() > fromIntType.getWidth()) {
-          if (isUnsignedCast)
-            return builder.create<arith::ExtUIOp>(loc, toType, operand);
-          return builder.create<arith::ExtSIOp>(loc, toType, operand);
-        }
-        if (toIntType.getWidth() < fromIntType.getWidth())
-          return builder.create<arith::TruncIOp>(loc, toType, operand);
-      }
-    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
-      // If operand is integer, cast directly to the float type.
-      // Note that it is unclear how to cast from BF16<->FP16.
-      if (operand.getType().isa<IntegerType>()) {
-        if (isUnsignedCast)
-          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
-        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
-      }
-      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
-        if (toFloatType.getWidth() > fromFloatType.getWidth())
-          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
-        if (toFloatType.getWidth() < fromFloatType.getWidth())
-          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
-      }
-    }
-
-    emitWarning(operand.getLoc()) << "could not cast operand of type "
-                                  << operand.getType() << " to " << toType;
-    return operand;
+    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
   }
 
   bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1744,8 +1744,14 @@
       if (!fillOp)
         continue;
       fillFound = true;
+      Value fillVal = fillOp.value();
+      auto resultType =
+          fillOp.result().getType().cast<RankedTensorType>().getElementType();
+      Value convertedVal =
+          convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
+                               /*isUnsignedCast =*/false);
       payload.getArgument(opOperand->getOperandNumber())
-          .replaceAllUsesWith(fillOp.value());
+          .replaceAllUsesWith(convertedVal);
     }
     return success(fillFound);
   }
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1017,6 +1017,30 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_fill_generic_different_dtype
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
+#map0 = affine_map<(d0) -> (d0)>
+func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
+  %1 = tensor.empty(%0) : tensor<?xf16>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
+  %3 = tensor.empty(%0) : tensor<?xf16>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
+  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
+    %5 = arith.addf %arg1, %arg2 : f16
+    linalg.yield %5 : f16
+  } -> tensor<?xf16>
+  return %4 : tensor<?xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_mixedaccess
 // CHECK-NOT: linalg.fill
 // CHECK: %[[GENERIC_OP:.*]] = linalg.generic