diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -242,6 +242,73 @@ rewriter); } + // tosa::CastOp + if (isa(op)) { + Type sourceType = elementTy; + Type targetType = resultTypes.front(); + + // A boolean value is considered to be unsigned when converting to + // floating-point. Otherwise, it will become `-1`. + if (sourceType.isInteger(/*width=*/1) && + mlir::UIToFPOp::areCastCompatible(sourceType, targetType)) { + return rewriter.create(loc, resultTypes, args, + mlir::None); + } else if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { + return rewriter.create(loc, resultTypes, args, + mlir::None); + } else if (sourceType.isa() && targetType.isa()) { + FloatType src = sourceType.cast(); + FloatType res = targetType.cast(); + if (src.getWidth() > res.getWidth()) { + return rewriter.create(loc, resultTypes, args, + mlir::None); + } else if (src.getWidth() < res.getWidth()) { + return rewriter.create(loc, resultTypes, args, + mlir::None); + } + // No conversion is needed for the same width floats + return args.front(); + } + if (targetType.isInteger(/*width=*/1)) { + // When casting to bool, we need to compare whether the value is equal to + // zero. + if (sourceType.isSignlessInteger()) { + Value zeroIntVal = rewriter.create<::mlir::ConstantIntOp>( + loc, 0, sourceType.cast().getWidth()); + return rewriter.create(loc, CmpIPredicate::ne, + args.front(), zeroIntVal); + } else if (sourceType.isa()) { + Value zero = rewriter.create( + loc, rewriter.getFloatAttr(sourceType, 0.0)); + return rewriter.create(loc, CmpFPredicate::UNE, + args.front(), zero); + } + } + if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) { + IntegerType src = sourceType.cast(); + IntegerType res = targetType.cast(); + if (src.getWidth() > res.getWidth()) { + return rewriter.create(loc, resultTypes, args, + mlir::None); + } else if (src.getWidth() == 1) { + // Special case boolean values, so they get casted to `1` instead of + // `-1`. + return rewriter.create(loc, resultTypes, args, + mlir::None); + } else if (src.getWidth() < res.getWidth()) { + return rewriter.create(loc, resultTypes, args, + mlir::None); + } + // No conversion is needed for the same width integers + return args.front(); + } + if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { + return rewriter.create(loc, resultTypes, args, + mlir::None); + } + return nullptr; + } + (void)rewriter.notifyMatchFailure( op, "unhandled op for linalg body calculation for elementwise op"); return nullptr; @@ -669,7 +736,7 @@ PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, - PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -180,6 +180,35 @@ // CHECK: select %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> + // CHECK: linalg.generic + // CHECK: fptosi + %19 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32> + + // CHECK: linalg.generic + // CHECK: constant 0 + // CHECK: cmpf + %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: fptrunc + %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16> + + // CHECK: linalg.generic + // CHECK: yield + %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32> + + return +} + +// ----- + +// CHECK-LABEL: @test_simple_f16 +func @test_simple_f16(%arg0: tensor<1xf16>) -> () { + + // CHECK: linalg.generic + // CHECK: fpext + %0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32> + return } @@ -255,6 +284,27 @@ // CHECK: select %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic + // CHECK: trunci + %16 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16> + + // CHECK: linalg.generic + // CHECK: yield + %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32> + + // CHECK: linalg.generic + // CHECK: sexti + %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64> + + // CHECK: linalg.generic + // CHECK: constant 0 + // CHECK: cmpi + %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: sitofp + %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32> + return }