diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -310,6 +311,55 @@ return rewriter.create(loc, resultTypes, result, extended); } + // tosa::ClzOp: count down from bitWidth, once per right shift, until 0. + if (isa(op) && elementTy.isa()) { + int bitWidth = elementTy.getIntOrFloatBitWidth(); + auto zero = + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + auto leadingZeros = rewriter.create( + loc, IntegerAttr::get(elementTy, bitWidth)); + + SmallVector operands = {args[0], leadingZeros, zero}; + SmallVector types = {elementTy, elementTy, elementTy}; + + auto whileOp = rewriter.create(loc, types, operands); + Block *before = rewriter.createBlock(&whileOp.before(), {}, types); + Block *after = rewriter.createBlock(&whileOp.after(), {}, types); + + // The conditional block of the while loop: continue while input != 0. + { + rewriter.setInsertionPointToStart(&whileOp.before().front()); + Value input = before->getArgument(0); + Value zero = before->getArgument(2); + + Value inputLargerThanZero = + rewriter.create(loc, CmpIPredicate::ne, input, zero); + rewriter.create(loc, inputLargerThanZero, + before->getArguments()); + } + + // The body of the while loop: shift right by one and decrement the count. 
+ { + rewriter.setInsertionPointToStart(&whileOp.after().front()); + Value input = after->getArgument(0); + Value leadingZeros = after->getArgument(1); + + auto one = rewriter.create( + loc, IntegerAttr::get(elementTy, 1)); + auto shifted = rewriter.create( + loc, resultTypes, input, one); + auto leadingZerosMinusOne = + rewriter.create(loc, resultTypes, leadingZeros, one); + + rewriter.create( + loc, + ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)})); + } + + rewriter.setInsertionPointAfter(whileOp); + return whileOp->getResult(1); + } + // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, resultTypes, args); @@ -2905,6 +2955,7 @@ PointwiseConverter, PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -32,15 +33,16 @@ : public TosaToLinalgOnTensorsBase { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnFunction() override { RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); + tensor::TensorDialect, scf::SCFDialect>(); target.addIllegalDialect(); // Not every TOSA op can be legalized to linalg. 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -357,37 +357,45 @@ // CHECK: addi %12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.while + // CHECK: cmpi ne + // CHECK: scf.condition + // CHECK: shift_right_unsigned + // CHECK: subi + // CHECK: scf.yield + %13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic // CHECK: cmpi - %13 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %14 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: cmpi - %14 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %15 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: select - %15 = "tosa.select"(%13, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %16 = "tosa.select"(%14, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %16 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %17 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %17 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %18 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + %19 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : 
(tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + %20 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: constant -32768 @@ -397,27 +405,27 @@ // CHECK: cmpi slt // CHECK: select // CHECK: trunci - %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16> + %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16> // CHECK: linalg.generic // CHECK: sexti - %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64> + %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64> // CHECK: linalg.generic // CHECK: constant 0 // CHECK: cmpi - %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1> + %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: sitofp - %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32> + %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: constant 0 // CHECK: cmpi sgt // CHECK: subi // CHECK: select - %24 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + %25 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6481,6 +6481,7 @@ ":LinalgOps", ":MathDialect", ":Pass", + ":SCFDialect", ":StandardOps", ":TensorDialect", ":TosaDialect",