diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -74,8 +74,8 @@ [&](Type llvm1DVectorTy, ValueRange operands) { LLVM::ConstantOp zero = rewriter.create(loc, boolType, boolZero); - return rewriter.replaceOpWithNewOp(op, llvm1DVectorTy, - operands[0], zero); + return rewriter.create(loc, llvm1DVectorTy, operands[0], + zero); }, rewriter); } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -259,54 +259,7 @@ // tosa::ClzOp if (isa(op) && elementTy.isa()) { - int bitWidth = elementTy.getIntOrFloatBitWidth(); - auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); - auto leadingZeros = rewriter.create( - loc, IntegerAttr::get(elementTy, bitWidth)); - - SmallVector operands = {args[0], leadingZeros, zero}; - SmallVector types = {elementTy, elementTy, elementTy}; - SmallVector locations = {loc, loc, loc}; - - auto whileOp = rewriter.create(loc, types, operands); - Block *before = - rewriter.createBlock(&whileOp.getBefore(), {}, types, locations); - Block *after = - rewriter.createBlock(&whileOp.getAfter(), {}, types, locations); - - // The conditional block of the while loop. - { - rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); - Value input = before->getArgument(0); - Value zero = before->getArgument(2); - - Value inputLargerThanZero = rewriter.create( - loc, arith::CmpIPredicate::ne, input, zero); - rewriter.create(loc, inputLargerThanZero, - before->getArguments()); - } - - // The body of the while loop: shift right until reaching a value of 0. - { - rewriter.setInsertionPointToStart(&whileOp.getAfter().front()); - Value input = after->getArgument(0); - Value leadingZeros = after->getArgument(1); - - auto one = rewriter.create( - loc, IntegerAttr::get(elementTy, 1)); - auto shifted = - rewriter.create(loc, resultTypes, input, one); - auto leadingZerosMinusOne = - rewriter.create(loc, resultTypes, leadingZeros, one); - - rewriter.create( - loc, - ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)})); - } - - rewriter.setInsertionPointAfter(whileOp); - return whileOp->getResult(1); + return rewriter.create(loc, elementTy, args[0]); } // tosa::LogicalAnd diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -366,12 +366,7 @@ // CHECK: arith.addi %12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - // CHECK: scf.while - // CHECK: arith.cmpi ne - // CHECK: scf.condition - // CHECK: arith.shrui - // CHECK: arith.subi - // CHECK: scf.yield + // CHECK: math.ctlz %13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic