diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -227,6 +227,45 @@ if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); + // tosa::ArithmeticRightShiftOp + if (isa(op) && elementTy.isa()) { + auto result = + rewriter.create(loc, resultTypes, args); + auto round = op->getAttr("round").cast().getValue(); + if (!round) { + return result; + } + + Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); + auto one = + rewriter.create(loc, IntegerAttr::get(elementTy, 1)); + auto zero = + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + auto i1one = + rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); + + // Checking that input2 > 0 (signed comparison against zero) + auto shiftValueGreaterThanZero = + rewriter.create(loc, CmpIPredicate::sgt, args[1], zero); + + // Checking that bit (input2 - 1) of input1, the last bit shifted out, is 1 + auto subtract = + rewriter.create(loc, resultTypes, args[1], one); + auto shifted = rewriter + .create(loc, resultTypes, + args[0], subtract) + ->getResults(); + auto truncated = + rewriter.create(loc, i1Ty, shifted, mlir::None); + auto isInputOdd = rewriter.create(loc, i1Ty, truncated, i1one); + + auto shouldRound = rewriter.create( + loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); + auto extended = + rewriter.create(loc, resultTypes, shouldRound); + return rewriter.create(loc, resultTypes, result, extended); + } + // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, resultTypes, args); @@ -284,6 +323,15 @@ return rewriter.create(loc, CmpIPredicate::sge, args[0], args[1]); + // tosa::EqualOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, CmpFPredicate::OEQ, args[0], + args[1]); + + if (isa(op) && elementTy.isSignlessInteger()) + return rewriter.create(loc, CmpIPredicate::eq, args[0], + args[1]); + // tosa::SelectOp if 
(isa(op)) { elementTy = op->getOperand(1).getType().cast().getElementType(); @@ -2202,9 +2250,11 @@ PointwiseConverter, PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -151,65 +151,69 @@ // CHECK: cmpf %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> + // CHECK: linalg.generic + // CHECK: cmpf + %12 = "tosa.equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> + // CHECK: linalg.generic // CHECK: select - %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: cmpf // CHECK: select - %13 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: cmpf // CHECK: select - %14 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: ceil - %15 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32> + %16 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: floor - %16 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32> + %17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: cmpf // CHECK: select - %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> 
tensor<1xf32> + %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: cmpf // CHECK: select - %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> + %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: negf // CHECK: exp // CHECK: addf // CHECK: divf - %19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32> + %20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: fptosi - %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32> + %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: constant 0 // CHECK: cmpf - %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1> + %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: fptrunc - %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16> + %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16> // CHECK: linalg.generic // CHECK: yield - %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32> + %24 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: divf - %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32> + %25 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32> return } @@ -285,58 +289,76 @@ // CHECK: shift_right_unsigned %9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic + // CHECK: shift_right_signed + %10 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: linalg.generic + // CHECK: constant 1 + // CHECK: constant 0 + // CHECK: constant true + // CHECK: cmpi + // CHECK: subi + // CHECK: shift_right_signed + // CHECK: trunci + // CHECK: and + // CHECK: and + 
// CHECK: zexti + // CHECK: addi + %11 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic // CHECK: cmpi - %10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %12 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: cmpi - %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: select - %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %14 = "tosa.select"(%12, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %15 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %16 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: trunci - %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16> + %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16> // 
CHECK: linalg.generic // CHECK: yield - %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32> + %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: sexti - %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64> + %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64> // CHECK: linalg.generic // CHECK: constant 0 // CHECK: cmpi - %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1> + %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: sitofp - %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32> + %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32> return }