diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -24,6 +24,28 @@ return SmallVector(nParallelLoops, getParallelIteratorTypeName()); } +template +static mlir::ConstantOp +createConstFromIntAttribute(Operation *op, std::string attrName, + Type requiredAttrType, PatternRewriter &rewriter) { + auto castedN = static_cast( + op->getAttr(attrName).cast().getValue().getSExtValue()); + return rewriter.create( + op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); +} + +template +static mlir::SelectOp clampHelper(Operation *op, ValueRange args, + mlir::ConstantOp min, mlir::ConstantOp max, + P pred, PatternRewriter &rewriter) { + Location loc = op->getLoc(); + auto smallerThanMin = rewriter.create(loc, pred, args[0], min); + auto minOrArg = + rewriter.create(loc, smallerThanMin, min, args[0]); + auto largerThanMax = rewriter.create(loc, pred, max, args[0]); + return rewriter.create(loc, largerThanMax, max, minOrArg); +} + static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef resultTypes, @@ -43,6 +65,42 @@ if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); + // tosa::SubOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + + // tosa::MulOp + if (isa(op) && elementTy.isa()) { + if (dyn_cast(op).shift() != 0) { + (void)rewriter.notifyMatchFailure(op, + "Cannot have shift value for float"); + return nullptr; + } + return rewriter.create(loc, resultTypes, args); + } + + if (isa(op) && elementTy.isa()) { + auto mul = + rewriter.create(loc, resultTypes, args[0], args[1]); + auto constant = + rewriter.create(loc, elementTy, op->getAttr("shift")); + return rewriter.create(loc, resultTypes, mul, + constant); + } + + // tosa::NegateOp + if (isa(op) && elementTy.isa()) { + auto constant = + rewriter.create(loc, IntegerAttr::get(elementTy, -1)); + return rewriter.create(loc, resultTypes, args[0], constant); + } + + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + // tosa::BitwiseAndOp if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); @@ -67,6 +125,10 @@ if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); + // tosa::RsqrtOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + // tosa::LogOp if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); @@ -75,13 +137,6 @@ if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); - // tosa::SubOp - if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); - - if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); - // tosa::TanhOp if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); @@ -104,6 +159,13 @@ return rewriter.create(loc, CmpIPredicate::sge, args[0], args[1]); + // tosa::SelectOp + if (isa(op)) { + elementTy = op->getOperand(1).getType().cast().getElementType(); + if (elementTy.isa() || elementTy.isa()) + return rewriter.create(loc, args[0], args[1], args[2]); + } + // tosa::MaximumOp if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create(loc, CmpFPredicate::OGT, @@ -138,6 +200,44 @@ if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); + // tosa::ClampOp + if (isa(op) && elementTy.isa()) { + auto min = rewriter.create(loc, elementTy, + op->getAttr("min_fp")); + auto max = rewriter.create(loc, elementTy, + op->getAttr("max_fp")); + return clampHelper(op, args, min, max, CmpFPredicate::OLT, + rewriter); + } + + if (isa(op) && elementTy.isa()) { + auto min = createConstFromIntAttribute(op, "min_int", elementTy, + rewriter); + auto max = createConstFromIntAttribute(op, "max_int", elementTy, + rewriter); + return clampHelper(op, args, min, max, CmpIPredicate::slt, + rewriter); + } + + // tosa::ReluNOp + if (isa(op) && elementTy.isa()) { + auto zero = + rewriter.create(loc, FloatAttr::get(elementTy, 0)); + auto n = rewriter.create(loc, elementTy, + op->getAttr("max_fp")); + return clampHelper(op, args, zero, n, CmpFPredicate::OLT, + rewriter); + } + + if (isa(op) && elementTy.isa()) { + auto zero = + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + auto n = createConstFromIntAttribute(op, "max_int", elementTy, + rewriter); + return clampHelper(op, args, zero, n, CmpIPredicate::slt, + rewriter); + } + (void)rewriter.notifyMatchFailure( op, "unhandled op for linalg body calculation for elementwise op"); return nullptr; @@ -245,16 +345,19 @@ MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert< PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, - PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter>( + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter>( context); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -116,43 +116,69 @@ // CHECK: subf %3 = "tosa.sub"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: linalg.generic + // CHECK: mulf + %4 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: negf + %5 = "tosa.negate"(%0) : (tensor<1xf32>) -> tensor<1xf32> + // CHECK: linalg.generic // CHECK: pow - %4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %6 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: rsqrt + %7 = "tosa.rsqrt"(%1) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: log - %5 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %8 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: exp - %6 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %9 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: cmpf - %7 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> + %10 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: cmpf - %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> + %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: select + %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: cmpf // CHECK: select - %9 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %13 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: cmpf // CHECK: select - %10 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %14 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: ceil - %11 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32> + %15 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: floor - %12 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32> + %16 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: cmpf + // CHECK: select + %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: cmpf + // CHECK: select + %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> return } @@ -169,44 +195,65 @@ // CHECK: subi %1 = "tosa.sub"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic + // CHECK: muli + %2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: linalg.generic + // CHECK: muli + %3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic // CHECK: and - %2 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: or - %3 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: xor - %4 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: shift_left - %5 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: shift_right_unsigned - %6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi - %7 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: cmpi - %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: select + %11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %9 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %10 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic + // CHECK: cmpi + // CHECK: select + %14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + + // CHECK: linalg.generic + // CHECK: cmpi + // CHECK: select + %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> return }