diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -289,6 +289,67 @@ rewriter); } + // tosa::CastOp + if (isa(op)) { + Type srcTy = elementTy; + Type dstTy = resultTypes.front(); + bool bitExtend = + srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); + + if (srcTy == dstTy) + return args.front(); + + if (srcTy.isa() && dstTy.isa() && bitExtend) + return rewriter.create(loc, resultTypes, args, mlir::None); + + if (srcTy.isa() && dstTy.isa() && !bitExtend) + return rewriter.create(loc, resultTypes, args, + mlir::None); + + // 1-bit integers need to be treated as signless. + if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); + + if (srcTy.isInteger(1) && dstTy.isa() && bitExtend) + return rewriter.create(loc, resultTypes, args, + mlir::None); + + // All other si-to-fp conversions should be handled by SIToFP. + if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); + + // Casting to boolean, floats need to only be checked as not-equal to zero. + if (srcTy.isa() && dstTy.isInteger(1)) { + Value zero = + rewriter.create(loc, rewriter.getFloatAttr(srcTy, 0.0)); + return rewriter.create(loc, CmpFPredicate::UNE, + args.front(), zero); + } + + if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); + + // Casting to boolean, integers need to only be checked as not-equal to + // zero. + if (srcTy.isa() && dstTy.isInteger(1)) { + Value zero = + rewriter.create(loc, 0, srcTy.getIntOrFloatBitWidth()); + return rewriter.create(loc, CmpIPredicate::ne, args.front(), + zero); + } + + if (srcTy.isa() && dstTy.isa() && bitExtend) + return rewriter.create(loc, resultTypes, args, + mlir::None); + + if (srcTy.isa() && dstTy.isa() && !bitExtend) + return rewriter.create(loc, resultTypes, args, + mlir::None); + } + (void)rewriter.notifyMatchFailure( op, "unhandled op for linalg body calculation for elementwise op"); return nullptr; @@ -891,7 +952,7 @@ PointwiseConverter, PointwiseConverter, PointwiseConverter, - PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -180,6 +180,35 @@ // CHECK: select %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> + // CHECK: linalg.generic + // CHECK: fptosi + %19 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32> + + // CHECK: linalg.generic + // CHECK: constant 0 + // CHECK: cmpf + %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: fptrunc + %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16> + + // CHECK: linalg.generic + // CHECK: yield + %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32> + + return +} + +// ----- + +// CHECK-LABEL: @test_simple_f16 +func @test_simple_f16(%arg0: tensor<1xf16>) -> () { + + // CHECK: linalg.generic + // CHECK: fpext + %0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32> + return } @@ -255,6 +284,27 @@ // CHECK: select %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic + // CHECK: trunci + %16 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16> + + // CHECK: linalg.generic + // CHECK: yield + %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32> + + // CHECK: linalg.generic + // CHECK: sexti + %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64> + + // CHECK: linalg.generic + // CHECK: constant 0 + // CHECK: cmpi + %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: sitofp + %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32> + return }