diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -29,7 +29,7 @@ PatternRewriter &rewriter) { Location loc = op->getLoc(); auto elementTy = - op->getResult(0).getType().cast().getElementType(); + op->getOperand(0).getType().cast().getElementType(); // tosa::AbsOp if (isa(op) && elementTy.isa()) @@ -66,6 +66,14 @@ if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); + // tosa::LogOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + + // tosa::ExpOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + // tosa::SubOp if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); @@ -77,6 +85,58 @@ if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); + // tosa::GreaterOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, CmpFPredicate::OGT, args[0], + args[1]); + + if (isa(op) && elementTy.isSignlessInteger()) + return rewriter.create(loc, CmpIPredicate::sgt, args[0], + args[1]); + + // tosa::GreaterEqualOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, CmpFPredicate::OGE, args[0], + args[1]); + + if (isa(op) && elementTy.isSignlessInteger()) + return rewriter.create(loc, CmpIPredicate::sge, args[0], + args[1]); + + // tosa::MaximumOp + if (isa(op) && elementTy.isa()) { + auto predicate = rewriter.create(loc, CmpFPredicate::OGT, + args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); + } + + if (isa(op) && elementTy.isSignlessInteger()) { + auto predicate = rewriter.create(loc, CmpIPredicate::sgt, + args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); + } + + // tosa::MinimumOp + if (isa(op) && elementTy.isa()) { + auto predicate = rewriter.create(loc, CmpFPredicate::OLT, + args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); + } + + if (isa(op) && elementTy.isSignlessInteger()) { + auto predicate = rewriter.create(loc, CmpIPredicate::slt, + args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); + } + + // tosa::CeilOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + + // tosa::FloorOp + if (isa(op) && elementTy.isa()) + return rewriter.create(loc, resultTypes, args); + rewriter.notifyMatchFailure( op, "unhandled op for linalg body calculation for elementwise op"); return nullptr; @@ -94,19 +154,21 @@ // For now require no broadcasting. Consider making it support broadcasting // operations. - Type uniqueTy = operation->getOperand(0).getType(); + Type uniqueInTy = operation->getOperand(0).getType(); bool allInputTypesEqual = llvm::all_of(operation->getOperandTypes(), - [&](Type operandTy) { return operandTy == uniqueTy; }); + [&](Type operandTy) { return operandTy == uniqueInTy; }); if (!allInputTypesEqual) return rewriter.notifyMatchFailure(operation, "All operands must have the same type"); - bool allResultTypesEqual = - llvm::all_of(operation->getResultTypes(), - [&](Type resultTy) { return resultTy == uniqueTy; }); - if (!allResultTypesEqual) + bool resultAndInputShapeEqual = + llvm::all_of(operation->getResults().getType(), [&](Type resultTy) { + return resultTy.cast().getShape() == t0.getShape(); + }); + + if (!resultAndInputShapeEqual) return rewriter.notifyMatchFailure( - operation, "All results must have the same type as the input"); + operation, "All results must have the same shape as the input"); // Construct the indexing maps needed for linalg.generic ops. SmallVector bodyArgTypes; @@ -179,10 +241,16 @@ MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert< PointwiseConverter, PointwiseConverter, - PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, - PointwiseConverter>(context); + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, PointwiseConverter, + PointwiseConverter, PointwiseConverter>( + context); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -100,6 +100,41 @@ // CHECK: linalg.generic // CHECK: pow %4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: log + %5 = "tosa.log"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: exp + %6 = "tosa.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: cmpf + %7 = "tosa.greater"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: cmpf + %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: cmpf + // CHECK: select + %9 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: cmpf + // CHECK: select + %10 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: ceil + %11 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: linalg.generic + // CHECK: floor + %12 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32> + return } @@ -135,6 +170,25 @@ // CHECK: shift_right_unsigned %6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic + // CHECK: cmpi + %7 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: cmpi + %8 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: cmpi + // CHECK: select + %9 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: linalg.generic + // CHECK: cmpi + // CHECK: select + %10 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + return }