diff --git a/mlir/benchmark/python/common.py b/mlir/benchmark/python/common.py --- a/mlir/benchmark/python/common.py +++ b/mlir/benchmark/python/common.py @@ -29,7 +29,7 @@ f"convert-scf-to-std," f"func-bufferize," f"arith-bufferize," - f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize)," + f"builtin.func(tensor-bufferize,finalizing-bufferize)," f"convert-vector-to-llvm" f"{{reassociate-fp-reductions=1 enable-index-optimizations=1}}," f"lower-affine," diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md --- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -87,7 +87,6 @@ pm.addNestedPass(createTCPBufferizePass()); // Bufferizes the downstream `tcp` dialect. pm.addNestedPass(createSCFBufferizePass()); pm.addNestedPass(createLinalgBufferizePass()); - pm.addNestedPass(createStdBufferizePass()); pm.addNestedPass(createTensorBufferizePass()); pm.addPass(createFuncBufferizePass()); diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -12,6 +12,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -1067,7 +1067,7 @@ %x = arith.cmpi "eq", %lhs, %rhs : vector<4xi64> // Generic form of the same operation. - %x = "std.arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64} + %x = "arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64} : (vector<4xi64>, vector<4xi64>) -> vector<4xi1> ``` }]; @@ -1140,4 +1140,55 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +def SelectOp : Arith_Op<"select", [ + AllTypesMatch<["true_value", "false_value", "result"]> + ] # ElementwiseMappable.traits> { + let summary = "select operation"; + let description = [{ + The `arith.select` operation chooses one value based on a binary condition + supplied as its first operand. If the value of the first operand is `1`, + the second operand is chosen, otherwise the third operand is chosen. + The second and the third operand must have the same type. + + The operation applies to vectors and tensors elementwise given the _shape_ + of all operands is identical. The choice is made for each element + individually based on the value at the same position as the element in the + condition operand. If an i1 is provided as the condition, the entire vector + or tensor is chosen. + + Example: + + ```mlir + // Custom form of scalar selection. + %x = arith.select %cond, %true, %false : i32 + + // Generic form of the same operation. + %x = "arith.select"(%cond, %true, %false) : (i1, i32, i32) -> i32 + + // Element-wise vector selection. + %vx = arith.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32> + + // Full vector selection. 
+ %vx = arith.select %cond, %vtrue, %vfalse : vector<42xf32> + ``` + }]; + + let arguments = (ins BoolLike:$condition, + AnyType:$true_value, + AnyType:$false_value); + let results = (outs AnyType:$result); + + let hasCanonicalizer = 1; + let hasFolder = 1; + + // FIXME: Switch this to use the declarative assembly format. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + #endif // ARITHMETIC_OPS diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -146,7 +146,7 @@ Note: This method can return multiple OpOperands, indicating that the given OpResult may at runtime alias with any of the OpOperands. This - is useful for branches and for ops such as `std.select`. + is useful for branches and for ops such as `arith.select`. }], /*retType=*/"SmallVector", /*methodName=*/"getAliasingOpOperand", diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -468,56 +468,6 @@ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } -//===----------------------------------------------------------------------===// -// SelectOp -//===----------------------------------------------------------------------===// - -def SelectOp : Std_Op<"select", [NoSideEffect, - AllTypesMatch<["true_value", "false_value", "result"]>, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "select operation"; - let description = [{ - The `select` operation chooses one value based on a binary condition - supplied as its first operand. If the value of the first operand is `1`, - the second operand is chosen, otherwise the third operand is chosen. - The second and the third operand must have the same type. - - The operation applies to vectors and tensors elementwise given the _shape_ - of all operands is identical. The choice is made for each element - individually based on the value at the same position as the element in the - condition operand. If an i1 is provided as the condition, the entire vector - or tensor is chosen. - - The `select` operation combined with [`cmpi`](#stdcmpi-cmpiop) can be used - to implement `min` and `max` with signed or unsigned comparison semantics. - - Example: - - ```mlir - // Custom form of scalar selection. - %x = select %cond, %true, %false : i32 - - // Generic form of the same operation. - %x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32 - - // Element-wise vector selection. - %vx = std.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32> - - // Full vector selection. 
- %vx = std.select %cond, %vtrue, %vfalse : vector<42xf32> - ``` - }]; - - let arguments = (ins BoolLike:$condition, - AnyType:$true_value, - AnyType:$false_value); - let results = (outs AnyType:$result); - - let hasCanonicalizer = 1; - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h +++ /dev/null @@ -1,18 +0,0 @@ -//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H -#define MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H - -namespace mlir { -class DialectRegistry; - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); -} // namespace mlir - -#endif // MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -23,9 +23,6 @@ class RewritePatternSet; -/// Creates an instance of std bufferization pass. -std::unique_ptr createStdBufferizePass(); - /// Creates an instance of func bufferization pass. std::unique_ptr createFuncBufferizePass(); diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -11,11 +11,6 @@ include "mlir/Pass/PassBase.td" -def StdBufferize : Pass<"std-bufferize", "FuncOp"> { - let summary = "Bufferize the std dialect"; - let constructor = "mlir::createStdBufferizePass()"; -} - def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { let summary = "Bufferize func/call/return ops"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h @@ -30,13 +30,14 @@ // Takes the parameters for a clamp and turns it into a series of ops. 
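// As a rough sketch (float case with the OLT predicate; the comparison op and
// predicate are supplied by the caller, so the exact ops vary), the helper now
// produces IR along the lines of:
//   %lt  = arith.cmpf olt, %arg, %min : f32
//   %low = arith.select %lt, %min, %arg : f32
//   %gt  = arith.cmpf olt, %max, %arg : f32
//   %res = arith.select %gt, %max, %low : f32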
template -mlir::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min, - arith::ConstantOp max, P pred, OpBuilder &rewriter) { +arith::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min, + arith::ConstantOp max, P pred, + OpBuilder &rewriter) { auto smallerThanMin = rewriter.create(loc, pred, arg, min); auto minOrArg = - rewriter.create(loc, smallerThanMin, min, arg); + rewriter.create(loc, smallerThanMin, min, arg); auto largerThanMax = rewriter.create(loc, pred, max, arg); - return rewriter.create(loc, largerThanMax, max, minOrArg); + return rewriter.create(loc, largerThanMax, max, minOrArg); } // Returns the values in an attribute as an array of values. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1368,14 +1368,14 @@ /// /// Example: /// ``` -/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val) +/// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val) /// : (tensor, tensor, tensor) /// -> tensor /// ``` /// can be scalarized to /// /// ``` -/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar) +/// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar) /// : (i1, f32, f32) -> f32 /// ``` template @@ -1430,12 +1430,12 @@ /// ``` /// /// ``` -/// %scalar_pred = "std.select"(%pred, %true_val, %false_val) +/// %scalar_pred = "arith.select"(%pred, %true_val, %false_val) /// : (i1, tensor, tensor) -> tensor /// ``` /// can be tensorized to /// ``` -/// %tensor_pred = "std.select"(%pred, %true_val, %false_val) +/// %tensor_pred = "arith.select"(%pred, %true_val, %false_val) /// : (tensor, tensor, tensor) /// -> tensor /// ``` diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -96,8 +96,8 @@ loc, arith::CmpIPredicate::slt, remainder, zeroCst); Value correctedRemainder = builder.create(loc, remainder, rhs); - Value result = builder.create(loc, isRemainderNegative, - correctedRemainder, remainder); + Value result = builder.create( + loc, isRemainderNegative, correctedRemainder, remainder); return result; } @@ -134,12 +134,12 @@ loc, arith::CmpIPredicate::slt, lhs, zeroCst); Value negatedDecremented = builder.create(loc, noneCst, lhs); Value dividend = - builder.create(loc, negative, negatedDecremented, lhs); + builder.create(loc, negative, negatedDecremented, lhs); Value quotient = builder.create(loc, dividend, rhs); Value correctedQuotient = builder.create(loc, noneCst, quotient); - Value result = - builder.create(loc, negative, correctedQuotient, quotient); + Value result = builder.create(loc, negative, + correctedQuotient, quotient); return result; } @@ -175,14 +175,14 @@ Value negated = builder.create(loc, zeroCst, lhs); Value decremented = builder.create(loc, lhs, oneCst); Value dividend = - builder.create(loc, nonPositive, negated, decremented); + builder.create(loc, nonPositive, negated, decremented); Value quotient = builder.create(loc, dividend, rhs); Value negatedQuotient = builder.create(loc, zeroCst, quotient); Value incrementedQuotient = builder.create(loc, quotient, oneCst); - Value result = builder.create(loc, nonPositive, negatedQuotient, - incrementedQuotient); + Value result = builder.create( + loc, nonPositive, negatedQuotient, 
incrementedQuotient); return result; } @@ -259,7 +259,8 @@ Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { auto cmpOp = builder.create(loc, predicate, value, *valueIt); - value = builder.create(loc, cmpOp.getResult(), value, *valueIt); + value = builder.create(loc, cmpOp.getResult(), value, + *valueIt); } return value; diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp --- a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp +++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp @@ -66,6 +66,8 @@ VectorConvertToLLVMPattern; using BitcastOpLowering = VectorConvertToLLVMPattern; +using SelectOpLowering = + VectorConvertToLLVMPattern; //===----------------------------------------------------------------------===// // Op Lowering Patterns @@ -292,7 +294,8 @@ IndexCastOpLowering, BitcastOpLowering, CmpIOpLowering, - CmpFOpLowering + CmpFOpLowering, + SelectOpLowering >(converter); // clang-format on } diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -190,6 +190,15 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts arith.select to spv.Select. +class SelectOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -780,6 +789,19 @@ return success(); } +//===----------------------------------------------------------------------===// +// SelectOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult +SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), + adaptor.getTrueValue(), + adaptor.getFalseValue()); + return success(); +} + //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// @@ -820,7 +842,8 @@ TypeCastingOpPattern, TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, - CmpFOpNanNonePattern, CmpFOpPattern + CmpFOpNanNonePattern, CmpFOpPattern, + SelectOpPattern >(typeConverter, patterns.getContext()); // clang-format on diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -227,10 +227,10 @@ Value one = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 1)); Value lhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsRealInfinite, one, zero), + loc, rewriter.create(loc, lhsRealInfinite, one, zero), lhsReal); Value lhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, lhsImagInfinite, one, zero), + loc, rewriter.create(loc, lhsImagInfinite, one, zero), lhsImag); Value lhsRealIsInfWithSignTimesRhsReal = rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); @@ -265,10 +265,10 @@ Value 
finiteNumInfiniteDenom = rewriter.create(loc, lhsFinite, rhsInfinite); Value rhsRealIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsRealInfinite, one, zero), + loc, rewriter.create(loc, rhsRealInfinite, one, zero), rhsReal); Value rhsImagIsInfWithSign = rewriter.create( - loc, rewriter.create(loc, rhsImagInfinite, one, zero), + loc, rewriter.create(loc, rhsImagInfinite, one, zero), rhsImag); Value rhsRealIsInfWithSignTimesLhsReal = rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); @@ -289,21 +289,21 @@ Value realAbsSmallerThanImagAbs = rewriter.create( loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); - Value resultReal = rewriter.create(loc, realAbsSmallerThanImagAbs, - resultReal1, resultReal2); - Value resultImag = rewriter.create(loc, realAbsSmallerThanImagAbs, - resultImag1, resultImag2); - Value resultRealSpecialCase3 = rewriter.create( + Value resultReal = rewriter.create( + loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); + Value resultImag = rewriter.create( + loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); + Value resultRealSpecialCase3 = rewriter.create( loc, finiteNumInfiniteDenom, resultReal4, resultReal); - Value resultImagSpecialCase3 = rewriter.create( + Value resultImagSpecialCase3 = rewriter.create( loc, finiteNumInfiniteDenom, resultImag4, resultImag); - Value resultRealSpecialCase2 = rewriter.create( + Value resultRealSpecialCase2 = rewriter.create( loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); - Value resultImagSpecialCase2 = rewriter.create( + Value resultImagSpecialCase2 = rewriter.create( loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); - Value resultRealSpecialCase1 = rewriter.create( + Value resultRealSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); - Value resultImagSpecialCase1 = rewriter.create( + Value resultImagSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); Value resultRealIsNaN = rewriter.create( @@ -312,9 +312,9 @@ loc, arith::CmpFPredicate::UNO, resultImag, zero); Value resultIsNaN = rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); - Value resultRealWithSpecialCases = rewriter.create( + Value resultRealWithSpecialCases = rewriter.create( loc, resultIsNaN, resultRealSpecialCase1, resultReal); - Value resultImagWithSpecialCases = rewriter.create( + Value resultImagWithSpecialCases = rewriter.create( loc, resultIsNaN, resultImagSpecialCase1, resultImag); rewriter.replaceOpWithNewOp( @@ -450,24 +450,26 @@ b.create(elementType, b.getZeroAttr(elementType)); Value one = b.create(elementType, b.getFloatAttr(elementType, 1)); - Value lhsRealIsInfFloat = b.create(lhsRealIsInf, one, zero); - lhsReal = b.create( + Value lhsRealIsInfFloat = + b.create(lhsRealIsInf, one, zero); + lhsReal = b.create( lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), lhsReal); - Value lhsImagIsInfFloat = b.create(lhsImagIsInf, one, zero); - lhsImag = b.create( + Value lhsImagIsInfFloat = + b.create(lhsImagIsInf, one, zero); + lhsImag = b.create( lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), lhsImag); Value lhsIsInfAndRhsRealIsNan = b.create(lhsIsInf, rhsRealIsNan); - rhsReal = - b.create(lhsIsInfAndRhsRealIsNan, - b.create(zero, rhsReal), rhsReal); + rhsReal = b.create( + lhsIsInfAndRhsRealIsNan, b.create(zero, rhsReal), + rhsReal); Value lhsIsInfAndRhsImagIsNan = b.create(lhsIsInf, rhsImagIsNan); - rhsImag = - b.create(lhsIsInfAndRhsImagIsNan, - b.create(zero, rhsImag), rhsImag); + rhsImag = b.create( + 
lhsIsInfAndRhsImagIsNan, b.create(zero, rhsImag), + rhsImag); // Case 2. `rhsReal` or `rhsImag` are infinite. Value rhsRealIsInf = @@ -479,24 +481,26 @@ b.create(arith::CmpFPredicate::UNO, lhsReal, lhsReal); Value lhsImagIsNan = b.create(arith::CmpFPredicate::UNO, lhsImag, lhsImag); - Value rhsRealIsInfFloat = b.create(rhsRealIsInf, one, zero); - rhsReal = b.create( + Value rhsRealIsInfFloat = + b.create(rhsRealIsInf, one, zero); + rhsReal = b.create( rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), rhsReal); - Value rhsImagIsInfFloat = b.create(rhsImagIsInf, one, zero); - rhsImag = b.create( + Value rhsImagIsInfFloat = + b.create(rhsImagIsInf, one, zero); + rhsImag = b.create( rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), rhsImag); Value rhsIsInfAndLhsRealIsNan = b.create(rhsIsInf, lhsRealIsNan); - lhsReal = - b.create(rhsIsInfAndLhsRealIsNan, - b.create(zero, lhsReal), lhsReal); + lhsReal = b.create( + rhsIsInfAndLhsRealIsNan, b.create(zero, lhsReal), + lhsReal); Value rhsIsInfAndLhsImagIsNan = b.create(rhsIsInf, lhsImagIsNan); - lhsImag = - b.create(rhsIsInfAndLhsImagIsNan, - b.create(zero, lhsImag), lhsImag); + lhsImag = b.create( + rhsIsInfAndLhsImagIsNan, b.create(zero, lhsImag), + lhsImag); Value recalc = b.create(lhsIsInf, rhsIsInf); // Case 3. One of the pairwise products of left hand side with right hand @@ -522,24 +526,24 @@ isSpecialCase = b.create(isSpecialCase, notRecalc); Value isSpecialCaseAndLhsRealIsNan = b.create(isSpecialCase, lhsRealIsNan); - lhsReal = - b.create(isSpecialCaseAndLhsRealIsNan, - b.create(zero, lhsReal), lhsReal); + lhsReal = b.create( + isSpecialCaseAndLhsRealIsNan, b.create(zero, lhsReal), + lhsReal); Value isSpecialCaseAndLhsImagIsNan = b.create(isSpecialCase, lhsImagIsNan); - lhsImag = - b.create(isSpecialCaseAndLhsImagIsNan, - b.create(zero, lhsImag), lhsImag); + lhsImag = b.create( + isSpecialCaseAndLhsImagIsNan, b.create(zero, lhsImag), + lhsImag); Value isSpecialCaseAndRhsRealIsNan = b.create(isSpecialCase, rhsRealIsNan); - rhsReal = - b.create(isSpecialCaseAndRhsRealIsNan, - b.create(zero, rhsReal), rhsReal); + rhsReal = b.create( + isSpecialCaseAndRhsRealIsNan, b.create(zero, rhsReal), + rhsReal); Value isSpecialCaseAndRhsImagIsNan = b.create(isSpecialCase, rhsImagIsNan); - rhsImag = - b.create(isSpecialCaseAndRhsImagIsNan, - b.create(zero, rhsImag), rhsImag); + rhsImag = b.create( + isSpecialCaseAndRhsImagIsNan, b.create(zero, rhsImag), + rhsImag); recalc = b.create(recalc, isSpecialCase); recalc = b.create(isNan, recalc); @@ -548,16 +552,16 @@ lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); Value newReal = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); - real = - b.create(recalc, b.create(inf, newReal), real); + real = b.create( + recalc, b.create(inf, newReal), real); // Recalculate imag part. 
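// (That is, the recalculated value is inf * (lhsImag * rhsReal + lhsReal * rhsImag),
//  and the arith.select below keeps the previously computed imag unless `recalc`
//  is set, mirroring the real-part handling above.)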
lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); Value newImag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); - imag = - b.create(recalc, b.create(inf, newImag), imag); + imag = b.create( + recalc, b.create(inf, newImag), imag); rewriter.replaceOpWithNewOp(op, type, real, imag); return success(); @@ -608,8 +612,8 @@ Value realSign = b.create(real, abs); Value imagSign = b.create(imag, abs); Value sign = b.create(type, realSign, imagSign); - rewriter.replaceOpWithNewOp(op, isZero, adaptor.getComplex(), - sign); + rewriter.replaceOpWithNewOp(op, isZero, + adaptor.getComplex(), sign); return success(); } }; diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -71,8 +71,9 @@ static bool matchSelectReduction(Block &block, ArrayRef lessThanPredicates, ArrayRef greaterThanPredicates, bool &isMin) { - static_assert(llvm::is_one_of::value, - "only std and llvm select ops are supported"); + static_assert( + llvm::is_one_of::value, + "only arithmetic and llvm select ops are supported"); // Expect exactly three operations in the block. if (block.empty() || llvm::hasSingleElement(block) || @@ -290,7 +291,7 @@ // Match select-based min/max reductions. bool isMin; - if (matchSelectReduction( + if (matchSelectReduction( reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || matchSelectReduction( @@ -299,7 +300,7 @@ return createDecl(builder, symbolTable, reduce, minMaxValueForFloat(type, !isMin)); } - if (matchSelectReduction( + if (matchSelectReduction( reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || matchSelectReduction( @@ -311,7 +312,7 @@ isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, decl, reduce); } - if (matchSelectReduction( + if (matchSelectReduction( reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || matchSelectReduction( diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -107,8 +107,8 @@ Value dimIsOne = b.create(loc, arith::CmpIPredicate::eq, lesserRankOperandExtent, one); - Value dim = b.create(loc, dimIsOne, broadcastedDim, - lesserRankOperandExtent); + Value dim = b.create( + loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); b.create(loc, dim); }) .getResult(0); @@ -144,7 +144,7 @@ for (Value v : llvm::drop_begin(ranks, 1)) { Value rankIsGreater = lb.create(arith::CmpIPredicate::ugt, v, maxRank); - maxRank = lb.create(rankIsGreater, v, maxRank); + maxRank = lb.create(rankIsGreater, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. @@ -259,7 +259,7 @@ for (Value v : llvm::drop_begin(ranks, 1)) { Value rankIsGreater = lb.create(arith::CmpIPredicate::ugt, v, maxRank); - maxRank = lb.create(rankIsGreater, v, maxRank); + maxRank = lb.create(rankIsGreater, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. 
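// (The running maximum of the operand ranks above is the usual cmp-plus-select
//  idiom, roughly:
//    %gt     = arith.cmpi ugt, %rank, %maxRank : index
//    %newMax = arith.select %gt, %rank, %maxRank : index
//  repeated once per operand; the names are illustrative.)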
@@ -619,7 +619,7 @@ Value add = b.create(originalIndex, rank); Value indexIsNegative = b.create(arith::CmpIPredicate::slt, originalIndex, zero); - Value index = b.create(indexIsNegative, add, originalIndex); + Value index = b.create(indexIsNegative, add, originalIndex); Value one = b.create(1); Value head = diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -387,9 +387,6 @@ } }; -// Straightforward lowerings. -using SelectOpLowering = VectorConvertToLLVMPattern; - /// Lower `std.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is /// ignored by the default lowering but should be propagated by any custom @@ -685,7 +682,6 @@ CondBranchOpLowering, ConstantOpLowering, ReturnOpLowering, - SelectOpLowering, SwitchOpLowering>(converter); // clang-format on } diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -46,15 +46,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts std.select to spv.Select. -class SelectOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - /// Converts std.br to spv.Branch. struct BranchOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -155,19 +146,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// SelectOp -//===----------------------------------------------------------------------===// - -LogicalResult -SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), - adaptor.getTrueValue(), - adaptor.getFalseValue()); - return success(); -} - //===----------------------------------------------------------------------===// // BranchOpPattern //===----------------------------------------------------------------------===// @@ -211,8 +189,8 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, - ReturnOpPattern, SelectOpPattern, BranchOpPattern, CondBranchOpPattern>( - typeConverter, context); + ReturnOpPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter, + context); } void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -59,7 +59,7 @@ auto cmp = rewriter.create(loc, arith::CmpIPredicate::sgt, args[0], zero); auto neg = rewriter.create(loc, zero, args[0]); - return rewriter.create(loc, cmp, args[0], neg); + return rewriter.create(loc, cmp, args[0], neg); } // tosa::AddOp @@ -380,33 +380,33 @@ if (isa(op)) { elementTy = op->getOperand(1).getType().cast().getElementType(); if (elementTy.isa() || elementTy.isa()) - return rewriter.create(loc, args[0], args[1], args[2]); + return rewriter.create(loc, args[0], args[1], args[2]); } // 
tosa::MaximumOp if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create( loc, arith::CmpFPredicate::OGT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::sgt, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::MinimumOp if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create( loc, arith::CmpFPredicate::OLT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::slt, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::CeilOp @@ -558,7 +558,7 @@ auto negative = rewriter.create( loc, arith::CmpFPredicate::OLT, args[0], zero); auto rounded = - rewriter.create(loc, negative, subbed, added); + rewriter.create(loc, negative, subbed, added); auto clamped = clampHelper( loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter); @@ -792,25 +792,25 @@ if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create( loc, arith::CmpFPredicate::OLT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::slt, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create( loc, arith::CmpFPredicate::OGT, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::sgt, args[0], args[1]); - return rewriter.create(loc, predicate, args[0], args[1]); + return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isInteger(1)) @@ -1525,9 +1525,9 @@ loc, rewriter.getI32IntegerAttr(1)); auto yOffset = - rewriter.create(loc, yPred, oneVal, zeroVal); + rewriter.create(loc, yPred, oneVal, zeroVal); auto xOffset = - rewriter.create(loc, xPred, oneVal, zeroVal); + rewriter.create(loc, xPred, oneVal, zeroVal); iy = rewriter.create(loc, iy, yOffset); ix = rewriter.create(loc, ix, xOffset); @@ -2052,9 +2052,9 @@ return; } - auto resultMax = rewriter.create(nestedLoc, predicate, - newValue, oldValue); - auto resultIndex = rewriter.create( + auto resultMax = rewriter.create( + nestedLoc, predicate, newValue, oldValue); + auto resultIndex = rewriter.create( nestedLoc, predicate, newIndex, oldIndex); nestedBuilder.create( nestedLoc, ValueRange({resultIndex, resultMax})); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -748,7 +748,7 @@ Value cmp = rewriter.create( loc, arith::CmpIPredicate::slt, dx, zero); - Value offset = rewriter.create(loc, cmp, dx, zero); + Value offset = rewriter.create(loc, 
cmp, dx, zero); return rewriter.create(loc, v, offset)->getResult(0); }; @@ -758,7 +758,7 @@ auto kH2 = padFn(kH1, y1, pad[3]); auto kHCmp = rewriter.create( loc, arith::CmpIPredicate::slt, kH2, one); - auto kH3 = rewriter.create(loc, kHCmp, one, kH2); + auto kH3 = rewriter.create(loc, kHCmp, one, kH2); // compute the horizontal component of coverage. auto kW0 = rewriter.create(loc, kernel[1]); @@ -766,7 +766,7 @@ auto kW2 = padFn(kW1, x1, pad[5]); auto kWCmp = rewriter.create( loc, arith::CmpIPredicate::slt, kW2, one); - auto kW3 = rewriter.create(loc, kWCmp, one, kW2); + auto kW3 = rewriter.create(loc, kWCmp, one, kW2); // Compute the total number of elements and normalize. Value count = rewriter.create(loc, kH3, kW3); diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -135,7 +135,7 @@ Value valueGreaterThanZero = rewriter.create( loc, arith::CmpIPredicate::sge, value32, zero32); - Value doubleRound64 = rewriter.create( + Value doubleRound64 = rewriter.create( loc, valueGreaterThanZero, roundAdd64, roundSub64); // We only perform double rounding if the shift value is greater than 32. @@ -143,8 +143,8 @@ loc, getConstantAttr(i32Ty, 32, rewriter)); Value shiftGreaterThanThirtyTwo = rewriter.create( loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); - round64 = rewriter.create(loc, shiftGreaterThanThirtyTwo, - doubleRound64, round64); + round64 = rewriter.create(loc, shiftGreaterThanThirtyTwo, + doubleRound64, round64); } // The computation below equates to the following pseudocode: diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -1436,8 +1436,8 @@ for (unsigned i = 0; i < newYieldOp->getNumOperands(); ++i) { Value result = newYieldOp->getOperand(i); Value iterArg = cast(newParentOp).getRegionIterArgs()[i]; - Value maskedResult = state.builder.create(result.getLoc(), mask, - result, iterArg); + Value maskedResult = state.builder.create( + result.getLoc(), mask, result, iterArg); LLVM_DEBUG( dbgs() << "\n[early-vect]+++++ masking a yielded vector value: " << maskedResult); diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -1378,6 +1378,173 @@ return BoolAttr::get(getContext(), val); } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +// Transforms a select of a boolean to arithmetic operations +// +// arith.select %arg, %x, %y : i1 +// +// becomes +// +// and(%arg, %x) or and(!%arg, %y) +struct SelectI1Simplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SelectOp op, + PatternRewriter &rewriter) const override { + if (!op.getType().isInteger(1)) + return failure(); + + Value falseConstant = + rewriter.create(op.getLoc(), true, 1); + Value notCondition = rewriter.create( + op.getLoc(), op.getCondition(), falseConstant); + + Value trueVal = rewriter.create( + op.getLoc(), op.getCondition(), op.getTrueValue()); + Value falseVal = 
rewriter.create(op.getLoc(), notCondition, + op.getFalseValue()); + rewriter.replaceOpWithNewOp(op, trueVal, falseVal); + return success(); + } +}; + +// select %arg, %c1, %c0 => extui %arg +struct SelectToExtUI : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SelectOp op, + PatternRewriter &rewriter) const override { + // Cannot extui i1 to i1, or i1 to f32 + if (!op.getType().isa() || op.getType().isInteger(1)) + return failure(); + + // select %x, c1, %c0 => extui %arg + if (matchPattern(op.getTrueValue(), m_One())) + if (matchPattern(op.getFalseValue(), m_Zero())) { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getCondition()); + return success(); + } + + // select %x, c0, %c1 => extui (xor %arg, true) + if (matchPattern(op.getTrueValue(), m_Zero())) + if (matchPattern(op.getFalseValue(), m_One())) { + rewriter.replaceOpWithNewOp( + op, op.getType(), + rewriter.create( + op.getLoc(), op.getCondition(), + rewriter.create( + op.getLoc(), 1, op.getCondition().getType()))); + return success(); + } + + return failure(); + } +}; + +void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult arith::SelectOp::fold(ArrayRef operands) { + Value trueVal = getTrueValue(); + Value falseVal = getFalseValue(); + if (trueVal == falseVal) + return trueVal; + + Value condition = getCondition(); + + // select true, %0, %1 => %0 + if (matchPattern(condition, m_One())) + return trueVal; + + // select false, %0, %1 => %1 + if (matchPattern(condition, m_Zero())) + return falseVal; + + // select %x, true, false => %x + if (getType().isInteger(1)) + if (matchPattern(getTrueValue(), m_One())) + if (matchPattern(getFalseValue(), m_Zero())) + return condition; + + if (auto cmp = dyn_cast_or_null(condition.getDefiningOp())) { + auto pred = cmp.getPredicate(); + if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { + auto cmpLhs = cmp.getLhs(); + auto cmpRhs = cmp.getRhs(); + + // %0 = arith.cmpi eq, %arg0, %arg1 + // %1 = arith.select %0, %arg0, %arg1 => %arg1 + + // %0 = arith.cmpi ne, %arg0, %arg1 + // %1 = arith.select %0, %arg0, %arg1 => %arg0 + + if ((cmpLhs == trueVal && cmpRhs == falseVal) || + (cmpRhs == trueVal && cmpLhs == falseVal)) + return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; + } + } + return nullptr; +} + +static void print(OpAsmPrinter &p, arith::SelectOp op) { + p << " " << op.getOperands(); + p.printOptionalAttrDict(op->getAttrs()); + p << " : "; + if (ShapedType condType = op.getCondition().getType().dyn_cast()) + p << condType << ", "; + p << op.getType(); +} + +static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { + Type conditionType, resultType; + SmallVector operands; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(resultType)) + return failure(); + + // Check for the explicit condition type if this is a masked tensor or vector. 
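+  // For reference, the two custom forms being round-tripped here (examples
+  // only, mirroring the op documentation):
+  //   %x  = arith.select %cond, %a, %b : i32
+  //   %vx = arith.select %vcond, %va, %vb : vector<42xi1>, vector<42xf32>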
+ if (succeeded(parser.parseOptionalComma())) { + conditionType = resultType; + if (parser.parseType(resultType)) + return failure(); + } else { + conditionType = parser.getBuilder().getI1Type(); + } + + result.addTypes(resultType); + return parser.resolveOperands(operands, + {conditionType, resultType, resultType}, + parser.getNameLoc(), result.operands); +} + +static LogicalResult verify(arith::SelectOp op) { + Type conditionType = op.getCondition().getType(); + if (conditionType.isSignlessInteger(1)) + return success(); + + // If the result type is a vector or tensor, the type can be a mask with the + // same elements. + Type resultType = op.getType(); + if (!resultType.isa()) + return op.emitOpError() + << "expected condition to be a signless i1, but got " + << conditionType; + Type shapedConditionType = getI1SameShape(resultType); + if (conditionType != shapedConditionType) + return op.emitOpError() + << "expected condition type to have the same shape " + "as the result type, expected " + << shapedConditionType << ", but got " << conditionType; + return success(); +} + //===----------------------------------------------------------------------===// // Atomic Enum //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt @@ -14,5 +14,6 @@ LINK_LIBS PUBLIC MLIRDialect + MLIRInferTypeOpInterface MLIRIR ) diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -14,12 +14,10 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +using namespace mlir; using namespace mlir::bufferization; -namespace mlir { -namespace arith { namespace { - /// Bufferization of arith.constant. Replace with memref.get_global. struct ConstantOpInterface : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return false; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return op->getOpResult(0) /*result*/; + } + + SmallVector + getAliasingOpOperand(Operation *op, OpResult opResult, + const BufferizationState &state) const { + return {&op->getOpOperand(1) /*true_value*/, + &op->getOpOperand(2) /*false_value*/}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationState &state) const { + auto selectOp = cast(op); + + // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. + // TODO: It would be more efficient to copy the result of the `select` op + // instead of its OpOperands. In the worst case, 2 copies are inserted at + // the moment (one for each tensor). When copying the op result, only one + // copy would be needed. 
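+    // Sketch of the rewrite (types illustrative): a tensor select such as
+    //   %r = arith.select %cond, %t, %f : tensor<4xf32>
+    // becomes the same op operating on the operands' buffers:
+    //   %r = arith.select %cond, %t_buf, %f_buf : memref<4xf32>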
+ Value trueBuffer = + *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); + Value falseBuffer = + *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); + replaceOpWithNewBufferizedOp( + rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); + return success(); + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationState &state) const { + return BufferRelation::None; + } +}; + } // namespace -} // namespace arith -} // namespace mlir void mlir::arith::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addOpInterface(); registry.addOpInterface(); + registry.addOpInterface(); } diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -40,7 +40,7 @@ Value minusOne = rewriter.create(loc, a, one); Value quotient = rewriter.create(loc, minusOne, b); Value plusOne = rewriter.create(loc, quotient, one); - rewriter.replaceOpWithNewOp(op, compare, zero, plusOne); + rewriter.replaceOpWithNewOp(op, compare, zero, plusOne); return success(); } }; @@ -62,7 +62,7 @@ // Compute x = (b>0) ? -1 : 1. Value compare = rewriter.create(loc, arith::CmpIPredicate::sgt, b, zero); - Value x = rewriter.create(loc, compare, minusOne, plusOne); + Value x = rewriter.create(loc, compare, minusOne, plusOne); // Compute positive res: 1 + ((x+a)/b). Value xPlusA = rewriter.create(loc, x, a); Value xPlusADivB = rewriter.create(loc, xPlusA, b); @@ -91,7 +91,8 @@ Value compareRes = rewriter.create(loc, firstTerm, secondTerm); // Perform substitution and return success. - rewriter.replaceOpWithNewOp(op, compareRes, posRes, negRes); + rewriter.replaceOpWithNewOp(op, compareRes, posRes, + negRes); return success(); } }; @@ -113,7 +114,7 @@ // Compute x = (b<0) ? 1 : -1. Value compare = rewriter.create(loc, arith::CmpIPredicate::slt, b, zero); - Value x = rewriter.create(loc, compare, plusOne, minusOne); + Value x = rewriter.create(loc, compare, plusOne, minusOne); // Compute negative res: -1 - ((x-a)/b). Value xMinusA = rewriter.create(loc, x, a); Value xMinusADivB = rewriter.create(loc, xMinusA, b); @@ -140,7 +141,8 @@ Value compareRes = rewriter.create(loc, firstTerm, secondTerm); // Perform substitution and return success. - rewriter.replaceOpWithNewOp(op, compareRes, negRes, posRes); + rewriter.replaceOpWithNewOp(op, compareRes, negRes, + posRes); return success(); } }; @@ -161,12 +163,12 @@ pred == arith::CmpFPredicate::ULT, "pred must be either UGT or ULT"); Value cmp = rewriter.create(loc, pred, lhs, rhs); - Value select = rewriter.create(loc, cmp, lhs, rhs); + Value select = rewriter.create(loc, cmp, lhs, rhs); // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. 
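// As an illustration (assuming the UGT predicate used for the float max case),
// the full expansion is roughly:
//   %cmp   = arith.cmpf ugt, %lhs, %rhs : f32
//   %sel   = arith.select %cmp, %lhs, %rhs : f32
//   %isnan = arith.cmpf uno, %rhs, %rhs : f32
//   %res   = arith.select %isnan, %rhs, %sel : f32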
Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, rhs, rhs); - rewriter.replaceOpWithNewOp(op, isNaN, rhs, select); + rewriter.replaceOpWithNewOp(op, isNaN, rhs, select); return success(); } }; @@ -182,7 +184,7 @@ Location loc = op.getLoc(); Value cmp = rewriter.create(loc, pred, lhs, rhs); - rewriter.replaceOpWithNewOp(op, cmp, lhs, rhs); + rewriter.replaceOpWithNewOp(op, cmp, lhs, rhs); return success(); } }; diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -404,12 +404,12 @@ } else { // Select nested loop lower/upper bounds depending on our position in // the multi-dimensional iteration space. - auto lb = nb.create(isBlockFirstCoord[loopIdx], - blockFirstCoord[loopIdx + 1], c0); + auto lb = nb.create( + isBlockFirstCoord[loopIdx], blockFirstCoord[loopIdx + 1], c0); - auto ub = nb.create(isBlockLastCoord[loopIdx], - blockEndCoord[loopIdx + 1], - tripCounts[loopIdx + 1]); + auto ub = nb.create(isBlockLastCoord[loopIdx], + blockEndCoord[loopIdx + 1], + tripCounts[loopIdx + 1]); nb.create(lb, ub, c1, ValueRange(), workLoopBuilder(loopIdx + 1)); @@ -831,8 +831,8 @@ arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin); Value bracketScalingFactor = b.create( llvm::APFloat(p.second), b.getF32Type()); - scalingFactor = - b.create(inBracket, bracketScalingFactor, scalingFactor); + scalingFactor = b.create(inBracket, bracketScalingFactor, + scalingFactor); } Value numWorkersIndex = b.create(numWorkerThreadsVal, b.getI32Type()); diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -258,7 +258,7 @@ AccumulatorFactory getCmpFactory() const { return [&](Value lhs, Value rhs) { Value cmp = rewriter.create(loc, predicate, lhs, rhs); - return rewriter.create(loc, cmp, lhs, rhs); + return rewriter.create(loc, cmp, lhs, rhs); }; } diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -59,7 +59,7 @@ /// operations may not have side-effects, as otherwise sinking (and hence /// duplicating them) is not legal. 
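/// (Presumably the benefit of sinking cheap, side-effect-free values such as
/// arith.select into the kernel is that cloning them avoids passing their
/// results in as extra kernel arguments.)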
static bool isSinkingBeneficiary(Operation *op) { - return isa(op); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Pass/Pass.h" @@ -56,7 +55,6 @@ linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry); - mlir::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); } diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp --- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp @@ -49,7 +49,8 @@ Value zero = rewriter.create(loc, floatZero); Value cmpRes = rewriter.create(loc, arith::CmpFPredicate::OGE, op.getOperand(), zero); - rewriter.replaceOpWithNewOp(op, cmpRes, positiveRes, negativeRes); + rewriter.replaceOpWithNewOp(op, cmpRes, positiveRes, + negativeRes); return success(); } diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -179,12 +179,12 @@ //----------------------------------------------------------------------------// static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { - return builder.create( + return builder.create( builder.create(arith::CmpFPredicate::OLT, a, b), a, b); } static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { - return builder.create( + return builder.create( builder.create(arith::CmpFPredicate::OGT, a, b), a, b); } @@ -311,7 +311,7 @@ Value reciprocal = builder.create(one, abs); Value compare = builder.create(arith::CmpFPredicate::OLT, abs, reciprocal); - Value x = builder.create(compare, abs, reciprocal); + Value x = builder.create(compare, abs, reciprocal); // Perform the Taylor series approximation for atan over the range // [-1.0, 1.0]. @@ -328,7 +328,7 @@ // Remap the solution for over [0.0, 1.0] to [0.0, inf] auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape); Value sub = builder.create(halfPi, p); - Value select = builder.create(compare, p, sub); + Value select = builder.create(compare, p, sub); // Correct for signing of the input. 
rewriter.replaceOpWithNewOp(op, select, operand); @@ -371,11 +371,11 @@ auto subPi = builder.create(atan, pi); auto atanGt = builder.create(arith::CmpFPredicate::OGT, atan, zero); - auto flippedAtan = builder.create(atanGt, subPi, addPi); + auto flippedAtan = builder.create(atanGt, subPi, addPi); // Determine whether to directly use atan or use the 180 degree flip auto xGt = builder.create(arith::CmpFPredicate::OGT, x, zero); - Value result = builder.create(xGt, atan, flippedAtan); + Value result = builder.create(xGt, atan, flippedAtan); // Handle x = 0, y > 0 Value xZero = @@ -383,22 +383,22 @@ Value yGt = builder.create(arith::CmpFPredicate::OGT, y, zero); Value isHalfPi = builder.create(xZero, yGt); auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape); - result = builder.create(isHalfPi, halfPi, result); + result = builder.create(isHalfPi, halfPi, result); // Handle x = 0, y < 0 Value yLt = builder.create(arith::CmpFPredicate::OLT, y, zero); Value isNegativeHalfPiPi = builder.create(xZero, yLt); auto negativeHalfPiPi = broadcast(builder, f32Cst(builder, -1.57079632679f), shape); - result = - builder.create(isNegativeHalfPiPi, negativeHalfPiPi, result); + result = builder.create(isNegativeHalfPiPi, negativeHalfPiPi, + result); // Handle x = 0, y = 0; Value yZero = builder.create(arith::CmpFPredicate::OEQ, y, zero); Value isNan = builder.create(xZero, yZero); Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape); - result = builder.create(isNan, cstNan, result); + result = builder.create(isNan, cstNan, result); rewriter.replaceOp(op, result); return success(); @@ -475,8 +475,8 @@ q = builder.create(x2, q, beta0); // Divide the numerator by the denominator. - Value res = builder.create(tinyMask, x, - builder.create(p, q)); + Value res = builder.create( + tinyMask, x, builder.create(p, q)); rewriter.replaceOp(op, res); @@ -561,11 +561,11 @@ // } else { x = x - 1.0; } Value mask = builder.create(arith::CmpFPredicate::OLT, x, cstCephesSQRTHF); - Value tmp = builder.create(mask, x, cstZero); + Value tmp = builder.create(mask, x, cstZero); x = builder.create(x, cstOne); e = builder.create( - e, builder.create(mask, cstOne, cstZero)); + e, builder.create(mask, cstOne, cstZero)); x = builder.create(x, tmp); Value x2 = builder.create(x, x); @@ -605,11 +605,11 @@ // • x == 0 -> -INF // • x < 0 -> NAN // • x == +INF -> +INF - Value aproximation = builder.create( + Value aproximation = builder.create( zeroMask, cstMinusInf, - builder.create( + builder.create( invalidMask, cstNan, - builder.create(posInfMask, cstPosInf, x))); + builder.create(posInfMask, cstPosInf, x))); rewriter.replaceOp(op, aproximation); @@ -683,7 +683,7 @@ Value logLarge = builder.create( x, builder.create( logU, builder.create(u, cstOne))); - Value approximation = builder.create( + Value approximation = builder.create( builder.create(uSmall, uInf), x, logLarge); rewriter.replaceOp(op, approximation); return success(); @@ -765,7 +765,8 @@ Value isNegativeArg = builder.create(arith::CmpFPredicate::OLT, op.getOperand(), zero); Value negArg = builder.create(op.getOperand()); - Value x = builder.create(isNegativeArg, negArg, op.getOperand()); + Value x = + builder.create(isNegativeArg, negArg, op.getOperand()); Value offset = offsets[0]; Value p[polyDegree + 1]; @@ -781,11 +782,13 @@ isLessThanBound[j] = builder.create(arith::CmpFPredicate::OLT, x, bounds[j]); for (int i = 0; i <= polyDegree; ++i) { - p[i] = builder.create(isLessThanBound[j], p[i], pp[j + 1][i]); - q[i] = 
builder.create(isLessThanBound[j], q[i], qq[j + 1][i]); + p[i] = builder.create(isLessThanBound[j], p[i], + pp[j + 1][i]); + q[i] = builder.create(isLessThanBound[j], q[i], + qq[j + 1][i]); } - offset = - builder.create(isLessThanBound[j], offset, offsets[j + 1]); + offset = builder.create(isLessThanBound[j], offset, + offsets[j + 1]); } isLessThanBound[intervalsCount - 1] = builder.create( arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); @@ -794,12 +797,13 @@ Value qPoly = makePolynomialCalculation(builder, q, x); Value rationalPoly = builder.create(pPoly, qPoly); Value formula = builder.create(offset, rationalPoly); - formula = builder.create(isLessThanBound[intervalsCount - 1], - formula, one); + formula = builder.create(isLessThanBound[intervalsCount - 1], + formula, one); // erf is odd function: erf(x) = -erf(-x). Value negFormula = builder.create(formula); - Value res = builder.create(isNegativeArg, negFormula, formula); + Value res = + builder.create(isNegativeArg, negFormula, formula); rewriter.replaceOp(op, res); @@ -917,14 +921,14 @@ builder.create(arith::CmpFPredicate::OGT, x, zerof32Const); Value isComputable = builder.create(rightBound, leftBound); - expY = builder.create( + expY = builder.create( isNegInfinityX, zerof32Const, - builder.create( + builder.create( isPosInfinityX, constPosInfinity, - builder.create(isComputable, expY, - builder.create(isPostiveX, - constPosInfinity, - underflow)))); + builder.create( + isComputable, expY, + builder.create(isPostiveX, constPosInfinity, + underflow)))); rewriter.replaceOp(op, expY); @@ -981,9 +985,10 @@ // (u - 1) * (x / ~x) Value expm1 = builder.create( uMinusOne, builder.create(x, logU)); - expm1 = builder.create(isInf, u, expm1); - Value approximation = builder.create( - uEqOne, x, builder.create(uMinusOneEqNegOne, cstNegOne, expm1)); + expm1 = builder.create(isInf, u, expm1); + Value approximation = builder.create( + uEqOne, x, + builder.create(uMinusOneEqNegOne, cstNegOne, expm1)); rewriter.replaceOp(op, approximation); return success(); } @@ -1053,7 +1058,7 @@ }; auto select = [&](Value cond, Value t, Value f) -> Value { - return builder.create(cond, t, f); + return builder.create(cond, t, f); }; auto fmla = [&](Value a, Value b, Value c) { @@ -1189,7 +1194,8 @@ // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if // x is zero or a positive denormalized float (equivalent to flushing positive // denormalized inputs to zero). 
-  Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton);
+  Value res =
+      builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
   rewriter.replaceOp(op, res);
 
   return success();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -67,7 +67,7 @@
     Value lhs = genericOp.getCurrentValue();
     Value rhs = op.value();
     Value cmp = bodyBuilder.create(loc, predicate, lhs, rhs);
-    Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
+    Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
     bodyBuilder.create(loc, select);
 
     rewriter.replaceOp(op, genericOp.getResult());
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1306,8 +1306,8 @@
       if (trueVal == falseVal)
         results[it.index()] = trueVal;
       else
-        results[it.index()] =
-            rewriter.create<SelectOp>(op.getLoc(), cond, trueVal, falseVal);
+        results[it.index()] = rewriter.create<arith::SelectOp>(
+            op.getLoc(), cond, trueVal, falseVal);
     }
 
     rewriter.replaceOp(op, results);
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -808,8 +808,8 @@
   Value stepped = b.create(t.getLoc(), iv, forOp.getStep());
   Value less = b.create(t.getLoc(), arith::CmpIPredicate::slt,
                         forOp.getUpperBound(), stepped);
-  Value ub =
-      b.create<SelectOp>(t.getLoc(), less, forOp.getUpperBound(), stepped);
+  Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
+                                       forOp.getUpperBound(), stepped);
 
   // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
   auto newForOp = b.create(t.getLoc(), iv, ub, originalStep);
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -36,7 +36,6 @@
   pm.addPass(createFuncBufferizePass());
   pm.addPass(arith::createConstantBufferizePass());
   pm.addPass(createTensorBufferizePass());
-  pm.addPass(createStdBufferizePass());
   pm.addPass(mlir::bufferization::createFinalizingBufferizePass());
   pm.addPass(createLowerAffinePass());
   pm.addPass(createConvertVectorToLLVMPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -787,8 +787,8 @@
   // Test if this is a scalarized reduction.
   if (codegen.redVal) {
     if (codegen.curVecLength > 1)
-      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
-                                      codegen.redVal);
+      rhs = rewriter.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
+                                             codegen.redVal);
     updateReduc(merger, codegen, rhs);
     return;
   }
@@ -1276,7 +1276,7 @@
     if (min) {
       Value cmp = rewriter.create(
           loc, arith::CmpIPredicate::ult, load, min);
-      min = rewriter.create<SelectOp>(loc, cmp, load, min);
+      min = rewriter.create<arith::SelectOp>(loc, cmp, load, min);
     } else {
       min = load;
     }
@@ -1363,7 +1363,7 @@
       Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, op1, op2);
       Value add = rewriter.create(loc, op3, one);
-      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
+      operands.push_back(rewriter.create<arith::SelectOp>(loc, cmp, add, op3));
       codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
     }
   }
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -304,23 +304,6 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// General helpers for comparison ops
-//===----------------------------------------------------------------------===//
-
-// Return the type of the same shape (scalar, vector or tensor) containing i1.
-static Type getI1SameShape(Type type) {
-  auto i1Type = IntegerType::get(type.getContext(), 1);
-  if (auto tensorType = type.dyn_cast<RankedTensorType>())
-    return RankedTensorType::get(tensorType.getShape(), i1Type);
-  if (type.isa<UnrankedTensorType>())
-    return UnrankedTensorType::get(i1Type);
-  if (auto vectorType = type.dyn_cast<VectorType>())
-    return VectorType::get(vectorType.getShape(), i1Type,
-                           vectorType.getNumScalableDims());
-  return i1Type;
-}
-
 //===----------------------------------------------------------------------===//
 // CondBranchOp
 //===----------------------------------------------------------------------===//
@@ -390,7 +373,7 @@
 ///       -> br ^bb1(A, ..., N)
 ///
 ///   cond_br %cond, ^bb1(A), ^bb1(B)
-///     -> %select = select %cond, A, B
+///     -> %select = arith.select %cond, A, B
 ///        br ^bb1(%select)
 ///
 struct SimplifyCondBranchIdenticalSuccessors
@@ -426,7 +409,7 @@
       if (std::get<0>(it) == std::get<1>(it))
         mergedOperands.push_back(std::get<0>(it));
       else
-        mergedOperands.push_back(rewriter.create<SelectOp>(
+        mergedOperands.push_back(rewriter.create<arith::SelectOp>(
            condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
     }
 
@@ -697,173 +680,6 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// SelectOp
-//===----------------------------------------------------------------------===//
-
-// Transforms a select of a boolean to arithmetic operations
-//
-//  select %arg, %x, %y : i1
-//
-// becomes
-//
-//  and(%arg, %x) or and(!%arg, %y)
-struct SelectI1Simplify : public OpRewritePattern<SelectOp> {
-  using OpRewritePattern<SelectOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SelectOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!op.getType().isInteger(1))
-      return failure();
-
-    Value falseConstant =
-        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
-    Value notCondition = rewriter.create<arith::XOrIOp>(
-        op.getLoc(), op.getCondition(), falseConstant);
-
-    Value trueVal = rewriter.create<arith::AndIOp>(
-        op.getLoc(), op.getCondition(), op.getTrueValue());
-    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
-                                                    op.getFalseValue());
-    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
-    return success();
-  }
-};
-
-// select %arg, %c1, %c0 => extui %arg
-struct SelectToExtUI : public OpRewritePattern<SelectOp> {
-  using OpRewritePattern<SelectOp>::OpRewritePattern;
-
-  
LogicalResult matchAndRewrite(SelectOp op, - PatternRewriter &rewriter) const override { - // Cannot extui i1 to i1, or i1 to f32 - if (!op.getType().isa() || op.getType().isInteger(1)) - return failure(); - - // select %x, c1, %c0 => extui %arg - if (matchPattern(op.getTrueValue(), m_One())) - if (matchPattern(op.getFalseValue(), m_Zero())) { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getCondition()); - return success(); - } - - // select %x, c0, %c1 => extui (xor %arg, true) - if (matchPattern(op.getTrueValue(), m_Zero())) - if (matchPattern(op.getFalseValue(), m_One())) { - rewriter.replaceOpWithNewOp( - op, op.getType(), - rewriter.create( - op.getLoc(), op.getCondition(), - rewriter.create( - op.getLoc(), 1, op.getCondition().getType()))); - return success(); - } - - return failure(); - } -}; - -void SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult SelectOp::fold(ArrayRef operands) { - auto trueVal = getTrueValue(); - auto falseVal = getFalseValue(); - if (trueVal == falseVal) - return trueVal; - - auto condition = getCondition(); - - // select true, %0, %1 => %0 - if (matchPattern(condition, m_One())) - return trueVal; - - // select false, %0, %1 => %1 - if (matchPattern(condition, m_Zero())) - return falseVal; - - // select %x, true, false => %x - if (getType().isInteger(1)) - if (matchPattern(getTrueValue(), m_One())) - if (matchPattern(getFalseValue(), m_Zero())) - return condition; - - if (auto cmp = dyn_cast_or_null(condition.getDefiningOp())) { - auto pred = cmp.getPredicate(); - if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { - auto cmpLhs = cmp.getLhs(); - auto cmpRhs = cmp.getRhs(); - - // %0 = arith.cmpi eq, %arg0, %arg1 - // %1 = select %0, %arg0, %arg1 => %arg1 - - // %0 = arith.cmpi ne, %arg0, %arg1 - // %1 = select %0, %arg0, %arg1 => %arg0 - - if ((cmpLhs == trueVal && cmpRhs == falseVal) || - (cmpRhs == trueVal && cmpLhs == falseVal)) - return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; - } - } - return nullptr; -} - -static void print(OpAsmPrinter &p, SelectOp op) { - p << " " << op.getOperands(); - p.printOptionalAttrDict(op->getAttrs()); - p << " : "; - if (ShapedType condType = op.getCondition().getType().dyn_cast()) - p << condType << ", "; - p << op.getType(); -} - -static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { - Type conditionType, resultType; - SmallVector operands; - if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType)) - return failure(); - - // Check for the explicit condition type if this is a masked tensor or vector. - if (succeeded(parser.parseOptionalComma())) { - conditionType = resultType; - if (parser.parseType(resultType)) - return failure(); - } else { - conditionType = parser.getBuilder().getI1Type(); - } - - result.addTypes(resultType); - return parser.resolveOperands(operands, - {conditionType, resultType, resultType}, - parser.getNameLoc(), result.operands); -} - -static LogicalResult verify(SelectOp op) { - Type conditionType = op.getCondition().getType(); - if (conditionType.isSignlessInteger(1)) - return success(); - - // If the result type is a vector or tensor, the type can be a mask with the - // same elements. 
- Type resultType = op.getType(); - if (!resultType.isa()) - return op.emitOpError() - << "expected condition to be a signless i1, but got " - << conditionType; - Type shapedConditionType = getI1SameShape(resultType); - if (conditionType != shapedConditionType) - return op.emitOpError() - << "expected condition type to have the same shape " - "as the result type, expected " - << shapedConditionType << ", but got " << conditionType; - return success(); -} - //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp +++ /dev/null @@ -1,77 +0,0 @@ -//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" - -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Operation.h" - -using namespace mlir; -using namespace mlir::bufferization; - -namespace mlir { -namespace { - -/// Bufferization of std.select. Just replace the operands. -struct SelectOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return false; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return false; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return op->getOpResult(0) /*result*/; - } - - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { - return {&op->getOpOperand(1) /*true_value*/, - &op->getOpOperand(2) /*false_value*/}; - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { - auto selectOp = cast(op); - // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. - // TODO: It would be more efficient to copy the result of the `select` op - // instead of its OpOperands. In the worst case, 2 copies are inserted at - // the moment (one for each tensor). When copying the op result, only one - // copy would be needed. 
- Value trueBuffer = - *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); - Value falseBuffer = - *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); - replaceOpWithNewBufferizedOp( - rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); - return success(); - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { - return BufferRelation::None; - } -}; - -} // namespace -} // namespace mlir - -void mlir::registerBufferizableOpInterfaceExternalModels( - DialectRegistry ®istry) { - registry.addOpInterface(); -} diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ /dev/null @@ -1,48 +0,0 @@ -//===- Bufferize.cpp - Bufferization for std ops --------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements bufferization of std ops. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "PassDetail.h" -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" - -using namespace mlir; -using namespace mlir::bufferization; - -namespace { -struct StdBufferizePass : public StdBufferizeBase { - void runOnOperation() override { - std::unique_ptr options = - getPartialBufferizationOptions(); - options->addToDialectFilter(); - - if (failed(bufferizeOp(getOperation(), *options))) - signalPassFailure(); - } - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - mlir::registerBufferizableOpInterfaceExternalModels(registry); - } -}; -} // namespace - -std::unique_ptr mlir::createStdBufferizePass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,6 +1,4 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms - BufferizableOpInterfaceImpl.cpp - Bufferize.cpp DecomposeCallGraphTypes.cpp FuncBufferize.cpp FuncConversions.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -698,7 +698,7 @@ rewriter.create(loc, rewriter.getIndexAttr(d)); Value val = rewriter.create(loc, arith::CmpIPredicate::slt, bnd, idx); - Value sel = rewriter.create(loc, val, trueVal, falseVal); + Value sel = rewriter.create(loc, val, trueVal, falseVal); auto pos = rewriter.getI64ArrayAttr(d); result = rewriter.create(loc, dstType, sel, result, pos); diff --git 
a/mlir/test/Analysis/test-match-reduction.mlir b/mlir/test/Analysis/test-match-reduction.mlir --- a/mlir/test/Analysis/test-match-reduction.mlir +++ b/mlir/test/Analysis/test-match-reduction.mlir @@ -52,7 +52,7 @@ outs(%out0t : tensor<4xf32>) { ^bb0(%in0: f32, %out0: f32): %cmp = arith.cmpf ogt, %in0, %out0 : f32 - %sel = select %cmp, %in0, %out0 : f32 + %sel = arith.select %cmp, %in0, %out0 : f32 linalg.yield %sel : f32 } -> tensor<4xf32> return diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -374,11 +374,11 @@ // CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index // CHECK-NEXT: %[[b:.*]] = arith.addi %[[a]], %{{.*}} : index // CHECK-NEXT: %[[c:.*]] = arith.cmpi sgt, %{{.*}}, %[[b]] : index -// CHECK-NEXT: %[[d:.*]] = select %[[c]], %{{.*}}, %[[b]] : index +// CHECK-NEXT: %[[d:.*]] = arith.select %[[c]], %{{.*}}, %[[b]] : index // CHECK-NEXT: %[[c10:.*]] = arith.constant 10 : index // CHECK-NEXT: %[[e:.*]] = arith.addi %{{.*}}, %[[c10]] : index // CHECK-NEXT: %[[f:.*]] = arith.cmpi slt, %{{.*}}, %[[e]] : index -// CHECK-NEXT: %[[g:.*]] = select %[[f]], %{{.*}}, %[[e]] : index +// CHECK-NEXT: %[[g:.*]] = arith.select %[[f]], %{{.*}}, %[[e]] : index // CHECK-NEXT: %[[c1_0:.*]] = arith.constant 1 : index // CHECK-NEXT: for %{{.*}} = %[[d]] to %[[g]] step %[[c1_0]] { // CHECK-NEXT: call @body2(%{{.*}}, %{{.*}}) : (index, index) -> () @@ -403,17 +403,17 @@ // CHECK-LABEL: func @min_reduction_tree // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[c01:.+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index -// CHECK-NEXT: %[[r01:.+]] = select %[[c01]], %{{.*}}, %{{.*}} : index +// CHECK-NEXT: %[[r01:.+]] = arith.select %[[c01]], %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[c012:.+]] = arith.cmpi slt, %[[r01]], %{{.*}} : index -// CHECK-NEXT: %[[r012:.+]] = select %[[c012]], %[[r01]], %{{.*}} : index +// CHECK-NEXT: %[[r012:.+]] = arith.select %[[c012]], %[[r01]], %{{.*}} : index // CHECK-NEXT: %[[c0123:.+]] = arith.cmpi slt, %[[r012]], %{{.*}} : index -// CHECK-NEXT: %[[r0123:.+]] = select %[[c0123]], %[[r012]], %{{.*}} : index +// CHECK-NEXT: %[[r0123:.+]] = arith.select %[[c0123]], %[[r012]], %{{.*}} : index // CHECK-NEXT: %[[c01234:.+]] = arith.cmpi slt, %[[r0123]], %{{.*}} : index -// CHECK-NEXT: %[[r01234:.+]] = select %[[c01234]], %[[r0123]], %{{.*}} : index +// CHECK-NEXT: %[[r01234:.+]] = arith.select %[[c01234]], %[[r0123]], %{{.*}} : index // CHECK-NEXT: %[[c012345:.+]] = arith.cmpi slt, %[[r01234]], %{{.*}} : index -// CHECK-NEXT: %[[r012345:.+]] = select %[[c012345]], %[[r01234]], %{{.*}} : index +// CHECK-NEXT: %[[r012345:.+]] = arith.select %[[c012345]], %[[r01234]], %{{.*}} : index // CHECK-NEXT: %[[c0123456:.+]] = arith.cmpi slt, %[[r012345]], %{{.*}} : index -// CHECK-NEXT: %[[r0123456:.+]] = select %[[c0123456]], %[[r012345]], %{{.*}} : index +// CHECK-NEXT: %[[r0123456:.+]] = arith.select %[[c0123456]], %[[r012345]], %{{.*}} : index // CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index // CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[r0123456]] step %[[c1]] { // CHECK-NEXT: call @body(%{{.*}}) : (index) -> () @@ -507,7 +507,7 @@ // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[v1:.*]] = arith.cmpi slt, %[[v0]], %[[c0]] : index // CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %[[c42]] : index -// CHECK-NEXT: %[[v3:.*]] = select %[[v1]], 
%[[v2]], %[[v0]] : index +// CHECK-NEXT: %[[v3:.*]] = arith.select %[[v1]], %[[v2]], %[[v0]] : index %0 = affine.apply #mapmod (%arg0) return %0 : index } @@ -526,10 +526,10 @@ // CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index // CHECK-NEXT: %[[v0:.*]] = arith.cmpi slt, %{{.*}}, %[[c0]] : index // CHECK-NEXT: %[[v1:.*]] = arith.subi %[[cm1]], %{{.*}} : index -// CHECK-NEXT: %[[v2:.*]] = select %[[v0]], %[[v1]], %{{.*}} : index +// CHECK-NEXT: %[[v2:.*]] = arith.select %[[v0]], %[[v1]], %{{.*}} : index // CHECK-NEXT: %[[v3:.*]] = arith.divsi %[[v2]], %[[c42]] : index // CHECK-NEXT: %[[v4:.*]] = arith.subi %[[cm1]], %[[v3]] : index -// CHECK-NEXT: %[[v5:.*]] = select %[[v0]], %[[v4]], %[[v3]] : index +// CHECK-NEXT: %[[v5:.*]] = arith.select %[[v0]], %[[v4]], %[[v3]] : index %0 = affine.apply #mapfloordiv (%arg0) return %0 : index } @@ -549,11 +549,11 @@ // CHECK-NEXT: %[[v0:.*]] = arith.cmpi sle, %{{.*}}, %[[c0]] : index // CHECK-NEXT: %[[v1:.*]] = arith.subi %[[c0]], %{{.*}} : index // CHECK-NEXT: %[[v2:.*]] = arith.subi %{{.*}}, %[[c1]] : index -// CHECK-NEXT: %[[v3:.*]] = select %[[v0]], %[[v1]], %[[v2]] : index +// CHECK-NEXT: %[[v3:.*]] = arith.select %[[v0]], %[[v1]], %[[v2]] : index // CHECK-NEXT: %[[v4:.*]] = arith.divsi %[[v3]], %[[c42]] : index // CHECK-NEXT: %[[v5:.*]] = arith.subi %[[c0]], %[[v4]] : index // CHECK-NEXT: %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] : index -// CHECK-NEXT: %[[v7:.*]] = select %[[v0]], %[[v5]], %[[v6]] : index +// CHECK-NEXT: %[[v7:.*]] = arith.select %[[v0]], %[[v5]], %[[v6]] : index %0 = affine.apply #mapceildiv (%arg0) return %0 : index } @@ -652,7 +652,7 @@ // CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]] // CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]] // CHECK: %[[cmp:.*]] = arith.cmpi slt, %[[first]], %[[second]] - // CHECK: select %[[cmp]], %[[first]], %[[second]] + // CHECK: arith.select %[[cmp]], %[[first]], %[[second]] %0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1) return %0 : index } @@ -667,7 +667,7 @@ // CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]] // CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]] // CHECK: %[[cmp:.*]] = arith.cmpi sgt, %[[first]], %[[second]] - // CHECK: select %[[cmp]], %[[first]], %[[second]] + // CHECK: arith.select %[[cmp]], %[[first]], %[[second]] %0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1) return %0 : index } diff --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir --- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir @@ -374,3 +374,12 @@ %0 = arith.cmpi ult, %arg0, %arg1 : vector<4x3xi32> std.return } + +// ----- + +// CHECK-LABEL: @select +func @select(%arg0 : i1, %arg1 : i32, %arg2 : i32) -> i32 { + // CHECK: = llvm.select %arg0, %arg1, %arg2 : i1, i32 + %0 = arith.select %arg0, %arg1, %arg2 : i32 + return %0 : i32 +} diff --git a/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir --- a/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir +++ b/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir @@ -194,3 +194,19 @@ arith.bitcast %arg0 : vector<2x4xf32> to vector<2x4xi32> return } + +// ----- + +// CHECK-LABEL: func @select_2d( +func @select_2d(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) { + // CHECK: %[[ARG0:.*]] = 
builtin.unrealized_conversion_cast %arg0 + // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1 + // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2 + // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi1>> + // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>> + // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[ARG2]][0] : !llvm.array<4 x vector<3xi32>> + // CHECK: %[[SELECT:.*]] = llvm.select %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi1>, vector<3xi32> + // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[SELECT]], %{{.*}}[0] : !llvm.array<4 x vector<3xi32>> + %0 = arith.select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32> + std.return +} diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -892,3 +892,21 @@ } } // end module + +// ----- + +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, {}> +} { + +// CHECK-LABEL: @select +func @select(%arg0 : i32, %arg1 : i32) { + %0 = arith.cmpi sle, %arg0, %arg1 : i32 + // CHECK: spv.Select + %1 = arith.select %0, %arg0, %arg1 : i32 + return +} + +} // end module diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -87,9 +87,9 @@ // CHECK: %[[LHS_IS_INFINITE:.*]] = arith.ori %[[LHS_REAL_INFINITE]], %[[LHS_IMAG_INFINITE]] : i1 // CHECK: %[[INF_NUM_FINITE_DENOM:.*]] = arith.andi %[[LHS_IS_INFINITE]], %[[RHS_IS_FINITE]] : i1 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[LHS_REAL_IS_INF:.*]] = select %[[LHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// CHECK: %[[LHS_REAL_IS_INF:.*]] = arith.select %[[LHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32 // CHECK: %[[LHS_REAL_IS_INF_WITH_SIGN:.*]] = math.copysign %[[LHS_REAL_IS_INF]], %[[LHS_REAL]] : f32 -// CHECK: %[[LHS_IMAG_IS_INF:.*]] = select %[[LHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// CHECK: %[[LHS_IMAG_IS_INF:.*]] = arith.select %[[LHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32 // CHECK: %[[LHS_IMAG_IS_INF_WITH_SIGN:.*]] = math.copysign %[[LHS_IMAG_IS_INF]], %[[LHS_IMAG]] : f32 // CHECK: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_REAL]] : f32 // CHECK: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] : f32 @@ -108,9 +108,9 @@ // CHECK: %[[RHS_IMAG_INFINITE:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32 // CHECK: %[[RHS_IS_INFINITE:.*]] = arith.ori %[[RHS_REAL_INFINITE]], %[[RHS_IMAG_INFINITE]] : i1 // CHECK: %[[FINITE_NUM_INFINITE_DENOM:.*]] = arith.andi %[[LHS_IS_FINITE]], %[[RHS_IS_INFINITE]] : i1 -// CHECK: %[[RHS_REAL_IS_INF:.*]] = select %[[RHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// CHECK: %[[RHS_REAL_IS_INF:.*]] = arith.select %[[RHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32 // CHECK: %[[RHS_REAL_IS_INF_WITH_SIGN:.*]] = math.copysign %[[RHS_REAL_IS_INF]], %[[RHS_REAL]] : f32 -// CHECK: %[[RHS_IMAG_IS_INF:.*]] = select %[[RHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32 +// CHECK: %[[RHS_IMAG_IS_INF:.*]] = arith.select 
%[[RHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32 // CHECK: %[[RHS_IMAG_IS_INF_WITH_SIGN:.*]] = math.copysign %[[RHS_IMAG_IS_INF]], %[[RHS_IMAG]] : f32 // CHECK: %[[RHS_REAL_IS_INF_WITH_SIGN_TIMES_LHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL_IS_INF_WITH_SIGN]] : f32 // CHECK: %[[RHS_IMAG_IS_INF_WITH_SIGN_TIMES_LHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG_IS_INF_WITH_SIGN]] : f32 @@ -122,19 +122,19 @@ // CHECK: %[[RESULT_IMAG_4:.*]] = arith.mulf %[[ZERO]], %[[ZERO_MULTIPLICATOR_2]] : f32 // CHECK: %[[REAL_ABS_SMALLER_THAN_IMAG_ABS:.*]] = arith.cmpf olt, %[[RHS_REAL_ABS]], %[[RHS_IMAG_ABS]] : f32 -// CHECK: %[[RESULT_REAL:.*]] = select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_REAL_1]], %[[RESULT_REAL_2]] : f32 -// CHECK: %[[RESULT_IMAG:.*]] = select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_IMAG_1]], %[[RESULT_IMAG_2]] : f32 -// CHECK: %[[RESULT_REAL_SPECIAL_CASE_3:.*]] = select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_REAL_4]], %[[RESULT_REAL]] : f32 -// CHECK: %[[RESULT_IMAG_SPECIAL_CASE_3:.*]] = select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_IMAG_4]], %[[RESULT_IMAG]] : f32 -// CHECK: %[[RESULT_REAL_SPECIAL_CASE_2:.*]] = select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_REAL_3]], %[[RESULT_REAL_SPECIAL_CASE_3]] : f32 -// CHECK: %[[RESULT_IMAG_SPECIAL_CASE_2:.*]] = select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_IMAG_3]], %[[RESULT_IMAG_SPECIAL_CASE_3]] : f32 -// CHECK: %[[RESULT_REAL_SPECIAL_CASE_1:.*]] = select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_REAL]], %[[RESULT_REAL_SPECIAL_CASE_2]] : f32 -// CHECK: %[[RESULT_IMAG_SPECIAL_CASE_1:.*]] = select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_IMAG]], %[[RESULT_IMAG_SPECIAL_CASE_2]] : f32 +// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_REAL_1]], %[[RESULT_REAL_2]] : f32 +// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[REAL_ABS_SMALLER_THAN_IMAG_ABS]], %[[RESULT_IMAG_1]], %[[RESULT_IMAG_2]] : f32 +// CHECK: %[[RESULT_REAL_SPECIAL_CASE_3:.*]] = arith.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_REAL_4]], %[[RESULT_REAL]] : f32 +// CHECK: %[[RESULT_IMAG_SPECIAL_CASE_3:.*]] = arith.select %[[FINITE_NUM_INFINITE_DENOM]], %[[RESULT_IMAG_4]], %[[RESULT_IMAG]] : f32 +// CHECK: %[[RESULT_REAL_SPECIAL_CASE_2:.*]] = arith.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_REAL_3]], %[[RESULT_REAL_SPECIAL_CASE_3]] : f32 +// CHECK: %[[RESULT_IMAG_SPECIAL_CASE_2:.*]] = arith.select %[[INF_NUM_FINITE_DENOM]], %[[RESULT_IMAG_3]], %[[RESULT_IMAG_SPECIAL_CASE_3]] : f32 +// CHECK: %[[RESULT_REAL_SPECIAL_CASE_1:.*]] = arith.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_REAL]], %[[RESULT_REAL_SPECIAL_CASE_2]] : f32 +// CHECK: %[[RESULT_IMAG_SPECIAL_CASE_1:.*]] = arith.select %[[RESULT_IS_INFINITY]], %[[INFINITY_RESULT_IMAG]], %[[RESULT_IMAG_SPECIAL_CASE_2]] : f32 // CHECK: %[[RESULT_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[RESULT_REAL]], %[[ZERO]] : f32 // CHECK: %[[RESULT_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RESULT_IMAG]], %[[ZERO]] : f32 // CHECK: %[[RESULT_IS_NAN:.*]] = arith.andi %[[RESULT_REAL_IS_NAN]], %[[RESULT_IMAG_IS_NAN]] : i1 -// CHECK: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : f32 -// CHECK: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32 +// CHECK: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : f32 +// CHECK: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] 
= arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex // CHECK: return %[[RESULT]] : complex @@ -253,18 +253,18 @@ // CHECK: %[[RHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_IMAG]], %[[RHS_IMAG]] : f32 // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32 +// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32 // CHECK: %[[TMP:.*]] = math.copysign %[[LHS_REAL_IS_INF_FLOAT]], %[[LHS_REAL]] : f32 -// CHECK: %[[LHS_REAL1:.*]] = select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32 -// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32 +// CHECK: %[[LHS_REAL1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32 +// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32 // CHECK: %[[TMP:.*]] = math.copysign %[[LHS_IMAG_IS_INF_FLOAT]], %[[LHS_IMAG]] : f32 -// CHECK: %[[LHS_IMAG1:.*]] = select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32 +// CHECK: %[[LHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32 // CHECK: %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_REAL_IS_NAN]] : i1 // CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL]] : f32 -// CHECK: %[[RHS_REAL1:.*]] = select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32 +// CHECK: %[[RHS_REAL1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32 // CHECK: %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_IMAG_IS_NAN]] : i1 // CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG]] : f32 -// CHECK: %[[RHS_IMAG1:.*]] = select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32 +// CHECK: %[[RHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32 // Case 2. RHS_REAL or RHS_IMAG are infinite. 
// CHECK: %[[RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32 @@ -272,18 +272,18 @@ // CHECK: %[[RHS_IS_INF:.*]] = arith.ori %[[RHS_REAL_IS_INF]], %[[RHS_IMAG_IS_INF]] : i1 // CHECK: %[[LHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_REAL1]], %[[LHS_REAL1]] : f32 // CHECK: %[[LHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_IMAG1]], %[[LHS_IMAG1]] : f32 -// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32 +// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32 // CHECK: %[[TMP:.*]] = math.copysign %[[RHS_REAL_IS_INF_FLOAT]], %[[RHS_REAL1]] : f32 -// CHECK: %[[RHS_REAL2:.*]] = select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32 -// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32 +// CHECK: %[[RHS_REAL2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32 +// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32 // CHECK: %[[TMP:.*]] = math.copysign %[[RHS_IMAG_IS_INF_FLOAT]], %[[RHS_IMAG1]] : f32 -// CHECK: %[[RHS_IMAG2:.*]] = select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32 +// CHECK: %[[RHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32 // CHECK: %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_REAL_IS_NAN]] : i1 // CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL1]] : f32 -// CHECK: %[[LHS_REAL2:.*]] = select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32 +// CHECK: %[[LHS_REAL2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32 // CHECK: %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_IMAG_IS_NAN]] : i1 // CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG1]] : f32 -// CHECK: %[[LHS_IMAG2:.*]] = select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32 +// CHECK: %[[LHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32 // CHECK: %[[RECALC:.*]] = arith.ori %[[LHS_IS_INF]], %[[RHS_IS_INF]] : i1 // Case 3. 
One of the pairwise products of left hand side with right hand side @@ -300,16 +300,16 @@ // CHECK: %[[IS_SPECIAL_CASE3:.*]] = arith.andi %[[IS_SPECIAL_CASE2]], %[[NOT_RECALC]] : i1 // CHECK: %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_REAL_IS_NAN]] : i1 // CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL2]] : f32 -// CHECK: %[[LHS_REAL3:.*]] = select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32 +// CHECK: %[[LHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32 // CHECK: %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_IMAG_IS_NAN]] : i1 // CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG2]] : f32 -// CHECK: %[[LHS_IMAG3:.*]] = select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32 +// CHECK: %[[LHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32 // CHECK: %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_REAL_IS_NAN]] : i1 // CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL2]] : f32 -// CHECK: %[[RHS_REAL3:.*]] = select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32 +// CHECK: %[[RHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32 // CHECK: %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_IMAG_IS_NAN]] : i1 // CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG2]] : f32 -// CHECK: %[[RHS_IMAG3:.*]] = select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32 +// CHECK: %[[RHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32 // CHECK: %[[RECALC2:.*]] = arith.ori %[[RECALC]], %[[IS_SPECIAL_CASE3]] : i1 // CHECK: %[[RECALC3:.*]] = arith.andi %[[IS_NAN]], %[[RECALC2]] : i1 @@ -318,14 +318,14 @@ // CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_IMAG3]] : f32 // CHECK: %[[NEW_REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32 // CHECK: %[[NEW_REAL_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_REAL]] : f32 -// CHECK: %[[FINAL_REAL:.*]] = select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32 +// CHECK: %[[FINAL_REAL:.*]] = arith.select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32 // Recalculate imag part. 
// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_REAL3]] : f32 // CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_IMAG3]] : f32 // CHECK: %[[NEW_IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32 // CHECK: %[[NEW_IMAG_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_IMAG]] : f32 -// CHECK: %[[FINAL_IMAG:.*]] = select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32 +// CHECK: %[[FINAL_IMAG:.*]] = arith.select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex // CHECK: return %[[RESULT]] : complex @@ -379,7 +379,7 @@ // CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32 // CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex -// CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex +// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex // CHECK: return %[[RESULT]] : complex // CHECK-LABEL: func @complex_sub diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir --- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -92,7 +92,7 @@ // CHECK: combiner // CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) // CHECK: %[[CMP:.*]] = arith.cmpf oge, %[[ARG0]], %[[ARG1]] -// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG0]], %[[ARG1]] +// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG0]], %[[ARG1]] // CHECK: omp.yield(%[[RES]] : f32) // CHECK-NOT: atomic @@ -108,7 +108,7 @@ scf.reduce(%one) : f32 { ^bb0(%lhs : f32, %rhs: f32): %cmp = arith.cmpf oge, %lhs, %rhs : f32 - %res = select %cmp, %lhs, %rhs : f32 + %res = arith.select %cmp, %lhs, %rhs : f32 scf.reduce.return %res : f32 } } @@ -126,7 +126,7 @@ // CHECK: combiner // CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) // CHECK: %[[CMP:.*]] = arith.cmpf oge, %[[ARG0]], %[[ARG1]] -// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG0]], %[[ARG1]] +// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG0]], %[[ARG1]] // CHECK: omp.yield(%[[RES]] : f32) // CHECK-NOT: atomic @@ -140,7 +140,7 @@ // CHECK: combiner // CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) // CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ARG0]], %[[ARG1]] -// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG1]], %[[ARG0]] +// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG0]] // CHECK: omp.yield(%[[RES]] : i64) // CHECK: atomic @@ -172,7 +172,7 @@ scf.reduce(%one) : f32 { ^bb0(%lhs : f32, %rhs: f32): %cmp = arith.cmpf oge, %lhs, %rhs : f32 - %res = select %cmp, %lhs, %rhs : f32 + %res = arith.select %cmp, %lhs, %rhs : f32 scf.reduce.return %res : f32 } // CHECK: arith.fptosi @@ -181,7 +181,7 @@ scf.reduce(%1) : i64 { ^bb1(%lhs: i64, %rhs: i64): %cmp = arith.cmpi slt, %lhs, %rhs : i64 - %res = select %cmp, %rhs, %lhs : i64 + %res = arith.select %cmp, %rhs, %lhs : i64 scf.reduce.return %res : i64 } // CHECK: omp.yield diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -366,9 +366,9 @@ // CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex> // CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], 
%[[C0]] : tensor<2xindex> // CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index -// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index +// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index // CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index -// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index +// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index // CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index // CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index // CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index @@ -382,7 +382,7 @@ // CHECK: %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index // CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex> // CHECK: %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0:.*]], %[[C1_0]] : index -// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index +// CHECK: %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index // CHECK: } // CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index // CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) { @@ -391,7 +391,7 @@ // CHECK: %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index // CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex> // CHECK: %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1:.*]], %[[C1_0]] : index -// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index +// CHECK: %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index // CHECK: } // CHECK: %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index // CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) { @@ -400,7 +400,7 @@ // CHECK: %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index // CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex> // CHECK: %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index -// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index +// CHECK: %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index // CHECK: } // CHECK: %[[OUT_BOUND_0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index // CHECK: %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) { @@ -456,9 +456,9 @@ // CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex> // CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex> // CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index -// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index +// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index // CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index -// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index +// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index // CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index // CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index // CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index @@ -472,7 +472,7 @@ // CHECK: %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index // CHECK: %[[EXTRACTED_0:.*]] = tensor.extract 
%[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex> // CHECK: %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0:.*]], %[[C1_0]] : index -// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index +// CHECK: %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index // CHECK: } // CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index // CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) { @@ -481,7 +481,7 @@ // CHECK: %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index // CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex> // CHECK: %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1:.*]], %[[C1_0]] : index -// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index +// CHECK: %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index // CHECK: } // CHECK: %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index // CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) { @@ -490,7 +490,7 @@ // CHECK: %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index // CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex> // CHECK: %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index -// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index +// CHECK: %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index // CHECK: } // CHECK: %[[OUT_BOUND_0:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index // CHECK: %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) { @@ -548,9 +548,9 @@ // CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex> // CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex> // CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index -// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index +// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index // CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index -// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index +// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index // CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index // CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index // CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index @@ -564,7 +564,7 @@ // CHECK: %[[IDX0:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF0]] : index // CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex> // CHECK: %[[DIM0_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_0:.*]], %[[C1]] : index -// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index +// CHECK: %[[MAX_DIM0:.*]] = arith.select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index // CHECK: } // CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index // CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) { @@ -573,7 +573,7 @@ // CHECK: %[[IDX1:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF1]] : index // CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex> // CHECK: %[[DIM1_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_1:.*]], %[[C1]] : index -// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index +// CHECK: %[[MAX_DIM1:.*]] = arith.select %[[DIM1_IS_1]], %[[DIM0]], 
%[[EXTRACTED_1]] : index // CHECK: } // CHECK: %[[VAL_36:.*]] = arith.cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index // CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) { @@ -582,7 +582,7 @@ // CHECK: %[[IDX2:.*]] = arith.subi %[[IDX]], %[[DIM_DIFF2]] : index // CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex> // CHECK: %[[DIM2_IS_1:.*]] = arith.cmpi eq, %[[EXTRACTED_2:.*]], %[[C1]] : index -// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index +// CHECK: %[[MAX_DIM2:.*]] = arith.select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index // CHECK: } // CHECK: tensor.yield %[[DIM2]] : index // CHECK: } : tensor @@ -614,7 +614,7 @@ // CHECK-NEXT: %[[RANK:.*]] = tensor.dim %[[SHAPE]], %[[C0]] : tensor // CHECK-NEXT: %[[POSINDEX:.*]] = arith.addi %[[INDEX]], %[[RANK]] : index // CHECK-NEXT: %[[ISNEG:.*]] = arith.cmpi slt, %[[INDEX]], %[[C0]] : index - // CHECK-NEXT: %[[SELECT:.*]] = select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index + // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index // CHECK-NEXT: %[[HEAD:.*]] = tensor.extract_slice %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor to tensor // CHECK-NEXT: %[[TAIL_SIZE:.*]] = arith.subi %[[RANK]], %[[SELECT]] : index diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -427,13 +427,6 @@ return } -// CHECK-LABEL: @select -func @select(%arg0 : i1, %arg1 : i32, %arg2 : i32) -> i32 { -// CHECK: = llvm.select %arg0, %arg1, %arg2 : i1, i32 - %0 = select %arg0, %arg1, %arg2 : i32 - return %0 : i32 -} - // CHECK-LABEL: @dfs_block_order func @dfs_block_order(%arg0: i32) -> (i32) { // CHECK-NEXT: %[[CST:.*]] = llvm.mlir.constant(42 : i32) : i32 @@ -519,19 +512,6 @@ // ----- -// CHECK-LABEL: func @select_2dvector( -func @select_2dvector(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) { - // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xi1>> - // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xi32>> - // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %arg2[0] : !llvm.array<4 x vector<3xi32>> - // CHECK: %[[SELECT:.*]] = llvm.select %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi1>, vector<3xi32> - // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[SELECT]], %0[0] : !llvm.array<4 x vector<3xi32>> - %0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32> - std.return -} - -// ----- - // CHECK-LABEL: func @switchi8( func @switchi8(%arg0 : i8) -> i32 { switch %arg0 : i8, [ diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// -// std.select +// arith.select //===----------------------------------------------------------------------===// module attributes { @@ -848,24 +848,6 @@ // ----- -module attributes { - spv.target_env = #spv.target_env< - #spv.vce, {}> -} { - -// CHECK-LABEL: @select 
-func @select(%arg0 : i32, %arg1 : i32) { - %0 = arith.cmpi sle, %arg0, %arg1 : i32 - // CHECK: spv.Select - %1 = select %0, %arg0, %arg1 : i32 - return -} - -} // end module - -// ----- - //===----------------------------------------------------------------------===// // std.return //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -228,28 +228,28 @@ // CHECK: [[PAD0:%.+]] = arith.constant 1 // CHECK: [[SUBP0:%.+]] = arith.subi [[IDX1]], [[PAD0]] // CHECK: [[P0CMP:%.+]] = arith.cmpi slt, [[SUBP0]], [[ZERO]] - // CHECK: [[SELP0:%.+]] = select [[P0CMP]], [[SUBP0]], [[ZERO]] + // CHECK: [[SELP0:%.+]] = arith.select [[P0CMP]], [[SUBP0]], [[ZERO]] // CHECK: [[ADDP0:%.+]] = arith.addi [[KH]], [[SELP0]] // CHECK: [[PAD1:%.+]] = arith.constant 1 // CHECK: [[SUBP1:%.+]] = arith.subi [[NY]], [[PAD1]] // CHECK: [[P1CMP:%.+]] = arith.cmpi slt, [[SUBP1]], [[ZERO]] - // CHECK: [[SELP1:%.+]] = select [[P1CMP]], [[SUBP1]], [[ZERO]] + // CHECK: [[SELP1:%.+]] = arith.select [[P1CMP]], [[SUBP1]], [[ZERO]] // CHECK: [[ADDP1:%.+]] = arith.addi [[ADDP0]], [[SELP1]] // CHECK: [[YCMP:%.+]] = arith.cmpi slt, [[ADDP1]], [[ONE]] - // CHECK: [[YSEL:%.+]] = select [[YCMP]], [[ONE]], [[ADDP1]] + // CHECK: [[YSEL:%.+]] = arith.select [[YCMP]], [[ONE]], [[ADDP1]] // CHECK: [[KW:%.+]] = arith.constant 4 : index // CHECK: [[PAD2:%.+]] = arith.constant 1 : index // CHECK: [[SUBP2:%.+]] = arith.subi [[IDX2]], [[PAD2]] // CHECK: [[P2CMP:%.+]] = arith.cmpi slt, [[SUBP2]], [[ZERO]] - // CHECK: [[SELP2:%.+]] = select [[P2CMP]], [[SUBP2]], [[ZERO]] + // CHECK: [[SELP2:%.+]] = arith.select [[P2CMP]], [[SUBP2]], [[ZERO]] // CHECK: [[ADDP2:%.+]] = arith.addi [[KW]], [[SELP2]] // CHECK: [[PAD3:%.+]] = arith.constant 1 : index // CHECK: [[SUBP3:%.+]] = arith.subi [[NX]], [[PAD3]] // CHECK: [[P3CMP:%.+]] = arith.cmpi slt, [[SUBP3]], [[ZERO]] - // CHECK: [[SELP3:%.+]] = select [[P3CMP]], [[SUBP3]], [[ZERO]] + // CHECK: [[SELP3:%.+]] = arith.select [[P3CMP]], [[SUBP3]], [[ZERO]] // CHECK: [[ADDP3:%.+]] = arith.addi [[ADDP2]], [[SELP3]] // CHECK: [[XCMP:%.+]] = arith.cmpi slt, [[ADDP3]], [[ONE]] - // CHECK: [[XSEL:%.+]] = select [[XCMP]], [[ONE]], [[ADDP3]] + // CHECK: [[XSEL:%.+]] = arith.select [[XCMP]], [[ONE]], [[ADDP3]] // Given the valid coverage of the pooling region, normalize the summation. 
// CHECK: [[C:%.+]] = arith.muli [[YSEL]], [[XSEL]] @@ -299,9 +299,9 @@ // CHECK: %[[MIN:.+]] = arith.constant -128 // CHECK: %[[MAX:.+]] = arith.constant 127 // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]] - // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]] + // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]] // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]] - // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] + // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]] // CHECK: linalg.yield %[[TRUNC]] %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8> @@ -328,9 +328,9 @@ // CHECK: %[[MIN:.+]] = arith.constant -32768 // CHECK: %[[MAX:.+]] = arith.constant 32767 // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]] - // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]] + // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]] // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]] - // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] + // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]] // CHECK: linalg.yield %[[TRUNC]] %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi16>) -> tensor<1x32x32x2xi16> diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -459,18 +459,18 @@ // CHECK-DAG: %[[C127:.+]] = arith.constant -127 // CHECK-DAG: %[[C126:.+]] = arith.constant 126 // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C127]] - // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C127]] + // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C127]] // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %arg1 - // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C126]], %[[SEL1]] + // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C126]], %[[SEL1]] %0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8> // CHECK: linalg.generic // CHECK-DAG: %[[C128:.+]] = arith.constant -128 // CHECK-DAG: %[[C127:.+]] = arith.constant 127 // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C128]] - // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C128]] + // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C128]] // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %arg1 - // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C127]], %[[SEL1]] + // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C127]], %[[SEL1]] %1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8> return @@ -511,9 +511,9 @@ // CHECK: [[MIN:%.+]] = arith.constant -128 // CHECK: [[MAX:%.+]] = arith.constant 127 // CHECK: [[PRED1:%.+]] = arith.cmpi slt, [[SUB]], [[MIN]] - // CHECK: [[LBOUND:%.+]] = select [[PRED1]], [[MIN]], [[SUB]] + // CHECK: [[LBOUND:%.+]] = arith.select [[PRED1]], 
[[MIN]], [[SUB]] // CHECK: [[PRED2:%.+]] = arith.cmpi slt, [[MAX]], [[SUB]] - // CHECK: [[UBOUND:%.+]] = select [[PRED2]], [[MAX]], [[LBOUND]] + // CHECK: [[UBOUND:%.+]] = arith.select [[PRED2]], [[MAX]], [[LBOUND]] // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]] // CHECK: linalg.yield [[TRUNC]] %0 = "tosa.negate"(%arg0) {quantization_info = { input_zp = 0 : i32, output_zp = 0 : i32}} : (tensor<1xi8>) -> tensor<1xi8> @@ -793,7 +793,7 @@ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32 - // CHECK: %[[RES:.+]] = select %[[CMP]], %arg1, %arg2 : f32 + // CHECK: %[[RES:.+]] = arith.select %[[CMP]], %arg1, %arg2 : f32 // CHECK: linalg.yield %[[RES]] : f32 // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor into tensor %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor) -> tensor @@ -973,8 +973,8 @@ // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127 // CHECK-DAG: [[MINLT:%.+]] = arith.cmpi slt, [[SCALED_ZEROED]], [[CMIN]] // CHECK-DAG: [[MAXLT:%.+]] = arith.cmpi slt, [[CMAX]], [[SCALED_ZEROED]] - // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] - // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]] + // CHECK-DAG: [[LOWER:%.+]] = arith.select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] + // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]] // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] // CHECK-DAG: linalg.yield [[TRUNC]] %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> (tensor<2xi8>) @@ -993,9 +993,9 @@ // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0 // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255 // CHECK-DAG: [[MINLT:%.+]] = arith.cmpi slt, [[SCALED_ZEROED]], [[CMIN]] - // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] + // CHECK-DAG: [[LOWER:%.+]] = arith.select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] // CHECK-DAG: [[MAXLT:%.+]] = arith.cmpi slt, [[CMAX]], [[SCALED_ZEROED]] - // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]] + // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]] // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8 // CHECK: linalg.yield [[CAST]] @@ -1046,9 +1046,9 @@ // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128 // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127 // CHECK-DAG: [[MINLT:%.+]] = arith.cmpi slt, [[SCALED_ZEROED]], [[CMIN]] - // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] + // CHECK-DAG: [[LOWER:%.+]] = arith.select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] // CHECK-DAG: [[MAXLT:%.+]] = arith.cmpi slt, [[CMAX]], [[SCALED_ZEROED]] - // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]] + // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]] // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] // CHECK: linalg.yield [[TRUNC]] %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> (tensor<2xi8>) @@ -1078,8 +1078,8 @@ // CHECK-DAG: [[CMAX:%.+]] = 
arith.constant 127 // CHECK-DAG: [[MINLT:%.+]] = arith.cmpi slt, [[SCALED_ZEROED]], [[CMIN]] // CHECK-DAG: [[MAXLT:%.+]] = arith.cmpi slt, [[CMAX]], [[SCALED_ZEROED]] - // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] - // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]] + // CHECK-DAG: [[LOWER:%.+]] = arith.select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] + // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]] // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] // CHECK-DAG: linalg.yield [[TRUNC]] %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32, 44 : i32], shift = [14 : i32, 15 : i32, 64 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>) -> (tensor<3xi8>) @@ -1338,8 +1338,8 @@ // CHECK: [[IDX:%.+]] = linalg.index 0 // CHECK: [[CAST:%.+]] = arith.index_cast [[IDX]] // CHECK: [[CMP:%.+]] = arith.cmpi sgt, %arg2, %arg4 - // CHECK: [[SELECT_VAL:%.+]] = select [[CMP]], %arg2, %arg4 - // CHECK: [[SELECT_IDX:%.+]] = select [[CMP]], [[CAST]], %arg3 + // CHECK: [[SELECT_VAL:%.+]] = arith.select [[CMP]], %arg2, %arg4 + // CHECK: [[SELECT_IDX:%.+]] = arith.select [[CMP]], [[CAST]], %arg3 // CHECK: linalg.yield [[SELECT_IDX]], [[SELECT_VAL]] %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x2xi32>) -> (tensor<2xi32>) @@ -1353,8 +1353,8 @@ // CHECK: [[IDX:%.+]] = linalg.index 1 // CHECK: [[CAST:%.+]] = arith.index_cast [[IDX]] // CHECK: [[CMP:%.+]] = arith.cmpi sgt, %arg2, %arg4 - // CHECK: [[SELECT_VAL:%.+]] = select [[CMP]], %arg2, %arg4 - // CHECK: [[SELECT_IDX:%.+]] = select [[CMP]], [[CAST]], %arg3 + // CHECK: [[SELECT_VAL:%.+]] = arith.select [[CMP]], %arg2, %arg4 + // CHECK: [[SELECT_IDX:%.+]] = arith.select [[CMP]], [[CAST]], %arg3 // CHECK: linalg.yield [[SELECT_IDX]], [[SELECT_VAL]] %1 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x2xi32>) -> (tensor<3xi32>) @@ -1388,8 +1388,8 @@ // CHECK: %[[IDX:.+]] = linalg.index 0 // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] // CHECK: %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3 - // CHECK: %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3 - // CHECK: %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2 + // CHECK: %[[SELECT_VAL:.+]] = arith.select %[[CMP]], %arg1, %arg3 + // CHECK: %[[SELECT_IDX:.+]] = arith.select %[[CMP]], %[[CAST]], %arg2 // CHECK: linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]] %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x?xi32>) -> (tensor) return @@ -1411,8 +1411,8 @@ // CHECK: %[[IDX:.+]] = linalg.index 1 // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] // CHECK: %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3 - // CHECK: %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3 - // CHECK: %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2 + // CHECK: %[[SELECT_VAL:.+]] = arith.select %[[CMP]], %arg1, %arg3 + // CHECK: %[[SELECT_IDX:.+]] = arith.select %[[CMP]], %[[CAST]], %arg2 // CHECK: linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]] %0 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x?xi32>) -> (tensor<3xi32>) return @@ -1587,21 +1587,21 @@ // CHECK-DAG: %[[VAL17:.+]] = arith.cmpf oge, %[[VAL13]], %[[ROUND]] // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 // CHECK-DAG: %[[ONE:.+]] = arith.constant 1 - // CHECK-DAG: %[[VAL18:.+]] = select %[[VAL16]], %[[ONE]], %[[ZERO]] - // CHECK-DAG: %[[VAL19:.+]] = select %[[VAL17]], %[[ONE]], %[[ZERO]] + // CHECK-DAG: %[[VAL18:.+]] = arith.select %[[VAL16]], %[[ONE]], %[[ZERO]] + // CHECK-DAG: %[[VAL19:.+]] = 
arith.select %[[VAL17]], %[[ONE]], %[[ZERO]] // CHECK-DAG: %[[VAL20:.+]] = arith.addi %[[VAL14]], %[[VAL18]] // CHECK-DAG: %[[VAL21:.+]] = arith.addi %[[VAL15]], %[[VAL19]] // This section applies bound checking to be within the input image. // CHECK-DAG: %[[VAL22:.+]] = arith.cmpi slt, %[[VAL20]], %[[XYMIN]] - // CHECK-DAG: %[[VAL23:.+]] = select %[[VAL22]], %[[XYMIN]], %[[VAL20]] + // CHECK-DAG: %[[VAL23:.+]] = arith.select %[[VAL22]], %[[XYMIN]], %[[VAL20]] // CHECK-DAG: %[[VAL24:.+]] = arith.cmpi slt, %[[YMAX]], %[[VAL20]] - // CHECK-DAG: %[[VAL25:.+]] = select %[[VAL24]], %[[YMAX]], %[[VAL23]] + // CHECK-DAG: %[[VAL25:.+]] = arith.select %[[VAL24]], %[[YMAX]], %[[VAL23]] // CHECK-DAG: %[[VAL26:.+]] = arith.cmpi slt, %[[VAL21]], %[[XYMIN]] - // CHECK-DAG: %[[VAL27:.+]] = select %[[VAL26]], %[[XYMIN]], %[[VAL21]] + // CHECK-DAG: %[[VAL27:.+]] = arith.select %[[VAL26]], %[[XYMIN]], %[[VAL21]] // CHECK-DAG: %[[VAL28:.+]] = arith.cmpi slt, %[[XMAX]], %[[VAL21]] - // CHECK-DAG: %[[VAL29:.+]] = select %[[VAL28]], %[[XMAX]], %[[VAL27]] + // CHECK-DAG: %[[VAL29:.+]] = arith.select %[[VAL28]], %[[XMAX]], %[[VAL27]] // Extract the nearest value using the computed indices. @@ -1646,24 +1646,24 @@ // Bound check each dimension. // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y0]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y0]] + // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y0]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y0]] - // CHECK: %[[YLO:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]] + // CHECK: %[[YLO:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y1]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y1]] + // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y1]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y1]] - // CHECK: %[[YHI:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]] + // CHECK: %[[YHI:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X0]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X0]] + // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X0]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X0]] - // CHECK: %[[XLO:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]] + // CHECK: %[[XLO:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X1]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X1]] + // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X1]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X1]] - // CHECK: %[[XHI:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]] + // CHECK: %[[XHI:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]] // Extract each corner of the bilinear interpolation. 
@@ -1738,21 +1738,21 @@ // CHECK-DAG: %[[VAL17:.+]] = arith.cmpi sge, %[[VAL13]], %[[ROUND]] // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 // CHECK-DAG: %[[ONE:.+]] = arith.constant 1 - // CHECK-DAG: %[[VAL18:.+]] = select %[[VAL16]], %[[ONE]], %[[ZERO]] - // CHECK-DAG: %[[VAL19:.+]] = select %[[VAL17]], %[[ONE]], %[[ZERO]] + // CHECK-DAG: %[[VAL18:.+]] = arith.select %[[VAL16]], %[[ONE]], %[[ZERO]] + // CHECK-DAG: %[[VAL19:.+]] = arith.select %[[VAL17]], %[[ONE]], %[[ZERO]] // CHECK-DAG: %[[VAL20:.+]] = arith.addi %[[VAL8]], %[[VAL18]] // CHECK-DAG: %[[VAL21:.+]] = arith.addi %[[VAL9]], %[[VAL19]] // This section applies bound checking to be within the input image. // CHECK-DAG: %[[VAL22:.+]] = arith.cmpi slt, %[[VAL20]], %[[XYMIN]] - // CHECK-DAG: %[[VAL23:.+]] = select %[[VAL22]], %[[XYMIN]], %[[VAL20]] + // CHECK-DAG: %[[VAL23:.+]] = arith.select %[[VAL22]], %[[XYMIN]], %[[VAL20]] // CHECK-DAG: %[[VAL24:.+]] = arith.cmpi slt, %[[YMAX]], %[[VAL20]] - // CHECK-DAG: %[[VAL25:.+]] = select %[[VAL24]], %[[YMAX]], %[[VAL23]] + // CHECK-DAG: %[[VAL25:.+]] = arith.select %[[VAL24]], %[[YMAX]], %[[VAL23]] // CHECK-DAG: %[[VAL26:.+]] = arith.cmpi slt, %[[VAL21]], %[[XYMIN]] - // CHECK-DAG: %[[VAL27:.+]] = select %[[VAL26]], %[[XYMIN]], %[[VAL21]] + // CHECK-DAG: %[[VAL27:.+]] = arith.select %[[VAL26]], %[[XYMIN]], %[[VAL21]] // CHECK-DAG: %[[VAL28:.+]] = arith.cmpi slt, %[[XMAX]], %[[VAL21]] - // CHECK-DAG: %[[VAL29:.+]] = select %[[VAL28]], %[[XMAX]], %[[VAL27]] + // CHECK-DAG: %[[VAL29:.+]] = arith.select %[[VAL28]], %[[XMAX]], %[[VAL27]] // Extract the nearest value using the computed indices. @@ -1794,24 +1794,24 @@ // Bound check each dimension. // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y0]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y0]] + // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y0]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y0]] - // CHECK: %[[YLO:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]] + // CHECK: %[[YLO:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y1]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y1]] + // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y1]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y1]] - // CHECK: %[[YHI:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]] + // CHECK: %[[YHI:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X0]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X0]] + // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X0]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X0]] - // CHECK: %[[XLO:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]] + // CHECK: %[[XLO:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X1]], %[[XYMIN]] - // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X1]] + // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X1]] // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X1]] - // CHECK: %[[XHI:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]] + // CHECK: %[[XHI:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]] // Extract each corner of the bilinear interpolation. 
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir --- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir +++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir @@ -37,10 +37,10 @@ // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]] // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]] // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : i32 - // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 + // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 // CHECK-DAG: [[C32_32:%.+]] = arith.constant 32 : i32 // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]] - // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] + // CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : i32 to i64 // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : i32 to i64 @@ -74,10 +74,10 @@ // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]] // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]] // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : vector<4xi32> - // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : vector<4xi1>, vector<4xi64> + // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : vector<4xi1>, vector<4xi64> // CHECK-DAG: [[C32_32:%.+]] = arith.constant dense<32> : vector<4xi32> // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]] - // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] + // CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : vector<4xi32> to vector<4xi64> // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : vector<4xi32> to vector<4xi64> @@ -110,9 +110,9 @@ // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]] // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]] // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : i48 - // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 + // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]] - // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] + // CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : i48 to i64 // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : i32 to i64 // CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : i8 to i64 diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir --- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir +++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir @@ -476,7 
+476,7 @@ // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> // CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[ld]] : vector<128xf32> -// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> // CHECK: affine.yield %[[new_acc]] : vector<128xf32> // CHECK: } // CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[vred:.*]] : vector<128xf32> into f32 @@ -509,7 +509,7 @@ // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> // CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[ld]] : vector<128xf32> -// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> // CHECK: affine.yield %[[new_acc]] : vector<128xf32> // CHECK: } // CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[vred:.*]] : vector<128xf32> into f32 @@ -562,7 +562,7 @@ // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> // CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[ld]] : vector<128xf32> -// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> // CHECK: affine.yield %[[new_acc]] : vector<128xf32> // ----- @@ -591,7 +591,7 @@ // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> // CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[ld]] : vector<128xf32> -// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> // CHECK: affine.yield %[[new_acc]] : vector<128xf32> // ----- @@ -624,7 +624,7 @@ // CHECK: %[[exp:.*]] = math.exp %[[ld]] : vector<128xf32> // CHECK: %[[add:.*]] = arith.addf %[[sum_iter]], %[[ld]] : vector<128xf32> // CHECK: %[[eadd:.*]] = arith.addf %[[esum_iter]], %[[exp]] : vector<128xf32> -// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[sum_iter]] : vector<128xi1>, vector<128xf32> -// CHECK: %[[new_eacc:.*]] = select %[[mask]], %[[eadd]], %[[esum_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[sum_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: %[[new_eacc:.*]] = arith.select %[[mask]], %[[eadd]], %[[esum_iter]] : vector<128xi1>, vector<128xf32> // CHECK: affine.yield %[[new_acc]], %[[new_eacc]] : vector<128xf32> // CHECK: } diff --git a/mlir/test/Dialect/Affine/parallelize.mlir b/mlir/test/Dialect/Affine/parallelize.mlir --- a/mlir/test/Dialect/Affine/parallelize.mlir +++ b/mlir/test/Dialect/Affine/parallelize.mlir @@ -27,7 +27,7 @@ %2 = affine.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> %3 = affine.load %1[%arg0 + %arg4, %arg1 * 2 + %arg5, %arg2 * 2 + %arg6, %arg3 + %arg7] : memref<1x18x18x64xf32> %4 = arith.cmpf ogt, %2, %3 : f32 - %5 = select 
%4, %2, %3 : f32 + %5 = arith.select %4, %2, %3 : f32 affine.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32> } } @@ -63,7 +63,7 @@ // CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> // CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32> // CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32 -// CHECK: %[[sel:.*]] = select %[[res]], %[[lhs]], %[[rhs]] : f32 +// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32 // CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32> // CHECK: } // CHECK: } diff --git a/mlir/test/Dialect/Arithmetic/bufferize.mlir b/mlir/test/Dialect/Arithmetic/bufferize.mlir --- a/mlir/test/Dialect/Arithmetic/bufferize.mlir +++ b/mlir/test/Dialect/Arithmetic/bufferize.mlir @@ -80,3 +80,19 @@ } // CHECK: } + +// ----- + +// CHECK-LABEL: func @select( +// CHECK-SAME: %[[PRED:.*]]: i1, +// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor, +// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor) -> tensor { +// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref +// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref +// CHECK: %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref +// CHECK: return %[[RET]] : tensor +func @select(%arg0: i1, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = arith.select %arg0, %arg1, %arg2 : tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir --- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir +++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir @@ -11,7 +11,7 @@ // CHECK: [[ZERO:%.+]] = arith.constant 0 : i32 // CHECK: [[MINONE:%.+]] = arith.constant -1 : i32 // CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32 -// CHECK: [[X:%.+]] = select [[CMP1]], [[MINONE]], [[ONE]] : i32 +// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : i32 // CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32 // CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32 // CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32 @@ -25,7 +25,7 @@ // CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1 // CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1 // CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1 -// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32 +// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32 } // ----- @@ -41,7 +41,7 @@ // CHECK: [[ZERO:%.+]] = arith.constant 0 : index // CHECK: [[MINONE:%.+]] = arith.constant -1 : index // CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index -// CHECK: [[X:%.+]] = select [[CMP1]], [[MINONE]], [[ONE]] : index +// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : index // CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index // CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index // CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index @@ -55,7 +55,7 @@ // CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1 // CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1 // CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1 -// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE3]] : index +// 
CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index } // ----- @@ -70,7 +70,7 @@ // CHECK: [[ZERO:%.+]] = arith.constant 0 : i32 // CHECK: [[MIN1:%.+]] = arith.constant -1 : i32 // CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32 -// CHECK: [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : i32 +// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : i32 // CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : i32 // CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32 // CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : i32 @@ -82,7 +82,7 @@ // CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1 // CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1 // CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1 -// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : i32 +// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32 } // ----- @@ -97,7 +97,7 @@ // CHECK: [[ZERO:%.+]] = arith.constant 0 : index // CHECK: [[MIN1:%.+]] = arith.constant -1 : index // CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index -// CHECK: [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : index +// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : index // CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index // CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index // CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index @@ -109,7 +109,7 @@ // CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1 // CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1 // CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1 -// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index +// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index } // ----- @@ -126,7 +126,7 @@ // CHECK: [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : i32 // CHECK: [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : i32 // CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : i32 -// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : i32 +// CHECK: [[RES:%.+]] = arith.select [[ISZERO]], [[ZERO]], [[REM]] : i32 } // ----- @@ -143,7 +143,7 @@ // CHECK: [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : index // CHECK: [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : index // CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : index -// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : index +// CHECK: [[RES:%.+]] = arith.select [[ISZERO]], [[ZERO]], [[REM]] : index } // ----- @@ -155,9 +155,9 @@ } // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) // CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 // CHECK-NEXT: return %[[RESULT]] : f32 // ----- @@ -169,9 +169,9 @@ } // CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>) // CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16> -// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] +// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] // CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16> -// CHECK-NEXT: %[[RESULT:.*]] = select 
%[[IS_NAN]], %[[RHS]], %[[SELECT]] +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] // CHECK-NEXT: return %[[RESULT]] : vector<4xf16> // ----- @@ -183,9 +183,9 @@ } // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) // CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 // CHECK-NEXT: return %[[RESULT]] : f32 diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir --- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir +++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir @@ -101,7 +101,7 @@ // CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref // CHECK-SAME: ) { // CHECK: scf.for %[[I:arg[0-9]+]] -// CHECK: select +// CHECK: arith.select // CHECK: scf.for %[[J:arg[0-9]+]] // CHECK: memref.store @@ -122,5 +122,5 @@ // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C10:.*]] = arith.constant 10 : index // CHECK: scf.for %[[I:arg[0-9]+]] -// CHECK-NOT: select +// CHECK-NOT: arith.select // CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1 diff --git a/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir --- a/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir +++ b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir @@ -18,15 +18,15 @@ // CHECK: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32 // CHECK: %[[workersIndex:.*]] = async.runtime.num_worker_threads : index // CHECK: %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index - // CHECK: %[[scalingFactor4:.*]] = select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32 + // CHECK: %[[scalingFactor4:.*]] = arith.select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32 // CHECK: %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index - // CHECK: %[[scalingFactor8:.*]] = select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32 + // CHECK: %[[scalingFactor8:.*]] = arith.select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32 // CHECK: %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index - // CHECK: %[[scalingFactor16:.*]] = select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32 + // CHECK: %[[scalingFactor16:.*]] = arith.select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32 // CHECK: %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index - // CHECK: %[[scalingFactor32:.*]] = select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32 + // CHECK: %[[scalingFactor32:.*]] = arith.select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32 // CHECK: %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index - // CHECK: %[[scalingFactor64:.*]] = select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32 + // CHECK: %[[scalingFactor64:.*]] = arith.select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32 // 
CHECK: %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32 // CHECK: %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32 // CHECK: %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32 diff --git a/mlir/test/Dialect/GPU/all-reduce-max.mlir b/mlir/test/Dialect/GPU/all-reduce-max.mlir --- a/mlir/test/Dialect/GPU/all-reduce-max.mlir +++ b/mlir/test/Dialect/GPU/all-reduce-max.mlir @@ -45,7 +45,7 @@ // CHECK: cond_br [[VAL_35]], ^bb2, ^bb3 // CHECK: ^bb2: // CHECK: [[VAL_36:%.*]] = arith.cmpf ugt, [[VAL_0]], [[VAL_34]] : f32 - // CHECK: [[VAL_37:%.*]] = select [[VAL_36]], [[VAL_0]], [[VAL_34]] : f32 + // CHECK: [[VAL_37:%.*]] = arith.select [[VAL_36]], [[VAL_0]], [[VAL_34]] : f32 // CHECK: br ^bb4([[VAL_37]] : f32) // CHECK: ^bb3: // CHECK: br ^bb4([[VAL_0]] : f32) @@ -54,7 +54,7 @@ // CHECK: cond_br [[VAL_40]], ^bb5, ^bb6 // CHECK: ^bb5: // CHECK: [[VAL_41:%.*]] = arith.cmpf ugt, [[VAL_38]], [[VAL_39]] : f32 - // CHECK: [[VAL_42:%.*]] = select [[VAL_41]], [[VAL_38]], [[VAL_39]] : f32 + // CHECK: [[VAL_42:%.*]] = arith.select [[VAL_41]], [[VAL_38]], [[VAL_39]] : f32 // CHECK: br ^bb7([[VAL_42]] : f32) // CHECK: ^bb6: // CHECK: br ^bb7([[VAL_38]] : f32) @@ -63,7 +63,7 @@ // CHECK: cond_br [[VAL_45]], ^bb8, ^bb9 // CHECK: ^bb8: // CHECK: [[VAL_46:%.*]] = arith.cmpf ugt, [[VAL_43]], [[VAL_44]] : f32 - // CHECK: [[VAL_47:%.*]] = select [[VAL_46]], [[VAL_43]], [[VAL_44]] : f32 + // CHECK: [[VAL_47:%.*]] = arith.select [[VAL_46]], [[VAL_43]], [[VAL_44]] : f32 // CHECK: br ^bb10([[VAL_47]] : f32) // CHECK: ^bb9: // CHECK: br ^bb10([[VAL_43]] : f32) @@ -72,7 +72,7 @@ // CHECK: cond_br [[VAL_50]], ^bb11, ^bb12 // CHECK: ^bb11: // CHECK: [[VAL_51:%.*]] = arith.cmpf ugt, [[VAL_48]], [[VAL_49]] : f32 - // CHECK: [[VAL_52:%.*]] = select [[VAL_51]], [[VAL_48]], [[VAL_49]] : f32 + // CHECK: [[VAL_52:%.*]] = arith.select [[VAL_51]], [[VAL_48]], [[VAL_49]] : f32 // CHECK: br ^bb13([[VAL_52]] : f32) // CHECK: ^bb12: // CHECK: br ^bb13([[VAL_48]] : f32) @@ -81,7 +81,7 @@ // CHECK: cond_br [[VAL_55]], ^bb14, ^bb15 // CHECK: ^bb14: // CHECK: [[VAL_56:%.*]] = arith.cmpf ugt, [[VAL_53]], [[VAL_54]] : f32 - // CHECK: [[VAL_57:%.*]] = select [[VAL_56]], [[VAL_53]], [[VAL_54]] : f32 + // CHECK: [[VAL_57:%.*]] = arith.select [[VAL_56]], [[VAL_53]], [[VAL_54]] : f32 // CHECK: br ^bb16([[VAL_57]] : f32) // CHECK: ^bb15: // CHECK: br ^bb16([[VAL_53]] : f32) @@ -90,19 +90,19 @@ // CHECK: ^bb17: // CHECK: [[VAL_59:%.*]], [[VAL_60:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_5]] : f32 // CHECK: [[VAL_61:%.*]] = arith.cmpf ugt, [[VAL_0]], [[VAL_59]] : f32 - // CHECK: [[VAL_62:%.*]] = select [[VAL_61]], [[VAL_0]], [[VAL_59]] : f32 + // CHECK: [[VAL_62:%.*]] = arith.select [[VAL_61]], [[VAL_0]], [[VAL_59]] : f32 // CHECK: [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle xor [[VAL_62]], [[VAL_7]], [[VAL_5]] : f32 // CHECK: [[VAL_65:%.*]] = arith.cmpf ugt, [[VAL_62]], [[VAL_63]] : f32 - // CHECK: [[VAL_66:%.*]] = select [[VAL_65]], [[VAL_62]], [[VAL_63]] : f32 + // CHECK: [[VAL_66:%.*]] = arith.select [[VAL_65]], [[VAL_62]], [[VAL_63]] : f32 // CHECK: [[VAL_67:%.*]], [[VAL_68:%.*]] = gpu.shuffle xor [[VAL_66]], [[VAL_8]], [[VAL_5]] : f32 // CHECK: [[VAL_69:%.*]] = arith.cmpf ugt, [[VAL_66]], [[VAL_67]] : f32 - // CHECK: [[VAL_70:%.*]] = select [[VAL_69]], [[VAL_66]], [[VAL_67]] : f32 + // CHECK: [[VAL_70:%.*]] = arith.select [[VAL_69]], [[VAL_66]], [[VAL_67]] : f32 // CHECK: [[VAL_71:%.*]], [[VAL_72:%.*]] = gpu.shuffle xor [[VAL_70]], [[VAL_9]], [[VAL_5]] : f32 // CHECK: 
[[VAL_73:%.*]] = arith.cmpf ugt, [[VAL_70]], [[VAL_71]] : f32 - // CHECK: [[VAL_74:%.*]] = select [[VAL_73]], [[VAL_70]], [[VAL_71]] : f32 + // CHECK: [[VAL_74:%.*]] = arith.select [[VAL_73]], [[VAL_70]], [[VAL_71]] : f32 // CHECK: [[VAL_75:%.*]], [[VAL_76:%.*]] = gpu.shuffle xor [[VAL_74]], [[VAL_10]], [[VAL_5]] : f32 // CHECK: [[VAL_77:%.*]] = arith.cmpf ugt, [[VAL_74]], [[VAL_75]] : f32 - // CHECK: [[VAL_78:%.*]] = select [[VAL_77]], [[VAL_74]], [[VAL_75]] : f32 + // CHECK: [[VAL_78:%.*]] = arith.select [[VAL_77]], [[VAL_74]], [[VAL_75]] : f32 // CHECK: br ^bb18([[VAL_78]] : f32) // CHECK: ^bb18([[VAL_79:%.*]]: f32): // CHECK: cond_br [[VAL_30]], ^bb19, ^bb20 @@ -129,7 +129,7 @@ // CHECK: cond_br [[VAL_89]], ^bb24, ^bb25 // CHECK: ^bb24: // CHECK: [[VAL_90:%.*]] = arith.cmpf ugt, [[VAL_86]], [[VAL_88]] : f32 - // CHECK: [[VAL_91:%.*]] = select [[VAL_90]], [[VAL_86]], [[VAL_88]] : f32 + // CHECK: [[VAL_91:%.*]] = arith.select [[VAL_90]], [[VAL_86]], [[VAL_88]] : f32 // CHECK: br ^bb26([[VAL_91]] : f32) // CHECK: ^bb25: // CHECK: br ^bb26([[VAL_86]] : f32) @@ -138,7 +138,7 @@ // CHECK: cond_br [[VAL_94]], ^bb27, ^bb28 // CHECK: ^bb27: // CHECK: [[VAL_95:%.*]] = arith.cmpf ugt, [[VAL_92]], [[VAL_93]] : f32 - // CHECK: [[VAL_96:%.*]] = select [[VAL_95]], [[VAL_92]], [[VAL_93]] : f32 + // CHECK: [[VAL_96:%.*]] = arith.select [[VAL_95]], [[VAL_92]], [[VAL_93]] : f32 // CHECK: br ^bb29([[VAL_96]] : f32) // CHECK: ^bb28: // CHECK: br ^bb29([[VAL_92]] : f32) @@ -147,7 +147,7 @@ // CHECK: cond_br [[VAL_99]], ^bb30, ^bb31 // CHECK: ^bb30: // CHECK: [[VAL_100:%.*]] = arith.cmpf ugt, [[VAL_97]], [[VAL_98]] : f32 - // CHECK: [[VAL_101:%.*]] = select [[VAL_100]], [[VAL_97]], [[VAL_98]] : f32 + // CHECK: [[VAL_101:%.*]] = arith.select [[VAL_100]], [[VAL_97]], [[VAL_98]] : f32 // CHECK: br ^bb32([[VAL_101]] : f32) // CHECK: ^bb31: // CHECK: br ^bb32([[VAL_97]] : f32) @@ -156,7 +156,7 @@ // CHECK: cond_br [[VAL_104]], ^bb33, ^bb34 // CHECK: ^bb33: // CHECK: [[VAL_105:%.*]] = arith.cmpf ugt, [[VAL_102]], [[VAL_103]] : f32 - // CHECK: [[VAL_106:%.*]] = select [[VAL_105]], [[VAL_102]], [[VAL_103]] : f32 + // CHECK: [[VAL_106:%.*]] = arith.select [[VAL_105]], [[VAL_102]], [[VAL_103]] : f32 // CHECK: br ^bb35([[VAL_106]] : f32) // CHECK: ^bb34: // CHECK: br ^bb35([[VAL_102]] : f32) @@ -165,7 +165,7 @@ // CHECK: cond_br [[VAL_109]], ^bb36, ^bb37 // CHECK: ^bb36: // CHECK: [[VAL_110:%.*]] = arith.cmpf ugt, [[VAL_107]], [[VAL_108]] : f32 - // CHECK: [[VAL_111:%.*]] = select [[VAL_110]], [[VAL_107]], [[VAL_108]] : f32 + // CHECK: [[VAL_111:%.*]] = arith.select [[VAL_110]], [[VAL_107]], [[VAL_108]] : f32 // CHECK: br ^bb38([[VAL_111]] : f32) // CHECK: ^bb37: // CHECK: br ^bb38([[VAL_107]] : f32) @@ -174,19 +174,19 @@ // CHECK: ^bb39: // CHECK: [[VAL_113:%.*]], [[VAL_114:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_5]] : f32 // CHECK: [[VAL_115:%.*]] = arith.cmpf ugt, [[VAL_86]], [[VAL_113]] : f32 - // CHECK: [[VAL_116:%.*]] = select [[VAL_115]], [[VAL_86]], [[VAL_113]] : f32 + // CHECK: [[VAL_116:%.*]] = arith.select [[VAL_115]], [[VAL_86]], [[VAL_113]] : f32 // CHECK: [[VAL_117:%.*]], [[VAL_118:%.*]] = gpu.shuffle xor [[VAL_116]], [[VAL_7]], [[VAL_5]] : f32 // CHECK: [[VAL_119:%.*]] = arith.cmpf ugt, [[VAL_116]], [[VAL_117]] : f32 - // CHECK: [[VAL_120:%.*]] = select [[VAL_119]], [[VAL_116]], [[VAL_117]] : f32 + // CHECK: [[VAL_120:%.*]] = arith.select [[VAL_119]], [[VAL_116]], [[VAL_117]] : f32 // CHECK: [[VAL_121:%.*]], [[VAL_122:%.*]] = gpu.shuffle xor [[VAL_120]], [[VAL_8]], [[VAL_5]] : f32 // CHECK: 
[[VAL_123:%.*]] = arith.cmpf ugt, [[VAL_120]], [[VAL_121]] : f32 - // CHECK: [[VAL_124:%.*]] = select [[VAL_123]], [[VAL_120]], [[VAL_121]] : f32 + // CHECK: [[VAL_124:%.*]] = arith.select [[VAL_123]], [[VAL_120]], [[VAL_121]] : f32 // CHECK: [[VAL_125:%.*]], [[VAL_126:%.*]] = gpu.shuffle xor [[VAL_124]], [[VAL_9]], [[VAL_5]] : f32 // CHECK: [[VAL_127:%.*]] = arith.cmpf ugt, [[VAL_124]], [[VAL_125]] : f32 - // CHECK: [[VAL_128:%.*]] = select [[VAL_127]], [[VAL_124]], [[VAL_125]] : f32 + // CHECK: [[VAL_128:%.*]] = arith.select [[VAL_127]], [[VAL_124]], [[VAL_125]] : f32 // CHECK: [[VAL_129:%.*]], [[VAL_130:%.*]] = gpu.shuffle xor [[VAL_128]], [[VAL_10]], [[VAL_5]] : f32 // CHECK: [[VAL_131:%.*]] = arith.cmpf ugt, [[VAL_128]], [[VAL_129]] : f32 - // CHECK: [[VAL_132:%.*]] = select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32 + // CHECK: [[VAL_132:%.*]] = arith.select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32 // CHECK: br ^bb40([[VAL_132]] : f32) // CHECK: ^bb40([[VAL_133:%.*]]: f32): // CHECK: store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3> diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -1724,9 +1724,9 @@ %cst = arith.constant 0.0 : f32 %idx = arith.constant 0 : index - // CHECK: select %{{.*}}, %[[t1]], %[[t2]] + // CHECK: arith.select %{{.*}}, %[[t1]], %[[t2]] // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "true"]} - %s = std.select %c, %t1, %t2 : tensor + %s = arith.select %c, %t1, %t2 : tensor // CHECK: tensor.insert // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]} %w = tensor.insert %cst into %s[%idx] : tensor @@ -1750,9 +1750,9 @@ %cst = arith.constant 0.0 : f32 %idx = arith.constant 0 : index - // CHECK: select %{{.*}}, %[[t1]], %[[t2]] + // CHECK: arith.select %{{.*}}, %[[t1]], %[[t2]] // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "false"]} - %s = std.select %c, %t1, %t2 : tensor + %s = arith.select %c, %t1, %t2 : tensor // CHECK: tensor.insert // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]} %w = tensor.insert %cst into %s[%idx] : tensor @@ -1779,9 +1779,9 @@ %cst = arith.constant 0.0 : f32 %idx = arith.constant 0 : index - // CHECK: select %{{.*}}, %[[t1]], %[[t2]] + // CHECK: arith.select %{{.*}}, %[[t1]], %[[t2]] // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]} - %s = std.select %c, %t1, %t2 : tensor + %s = arith.select %c, %t1, %t2 : tensor // CHECK: tensor.insert // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]} %w = tensor.insert %cst into %s[%idx] : tensor diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1009,7 +1009,7 @@ %B : tensor<4xf32> {linalg.inplaceable = false}) -> tensor<4xf32> { - // CHECK: %[[r:.*]] = select %[[cond]], %[[A]], %[[B]] + // CHECK: %[[r:.*]] = arith.select %[[cond]], %[[A]], %[[B]] %r = scf.if %b -> (tensor<4xf32>) { scf.yield %A : tensor<4xf32> } else { @@ -1321,8 +1321,8 @@ // CHECK: memref.copy %[[t1]], %[[alloc]] // CHECK: memref.store %{{.*}}, %[[alloc]] %w = tensor.insert %cst into %t1[%idx] : tensor - // CHECK: %[[select:.*]] = select 
%{{.*}}, %[[t1]], %[[t2]] - %s = std.select %c, %t1, %t2 : tensor + // CHECK: %[[select:.*]] = arith.select %{{.*}}, %[[t1]], %[[t2]] + %s = arith.select %c, %t1, %t2 : tensor // CHECK: return %[[select]], %[[alloc]] return %s, %w : tensor, tensor } @@ -1343,8 +1343,8 @@ // CHECK: %[[alloc:.*]] = memref.alloc // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] // CHECK: memref.copy %[[t1]], %[[alloc]] - // CHECK: %[[select:.*]] = select %{{.*}}, %[[casted]], %[[t2]] - %s = std.select %c, %t1, %t2 : tensor + // CHECK: %[[select:.*]] = arith.select %{{.*}}, %[[casted]], %[[t2]] + %s = arith.select %c, %t1, %t2 : tensor // CHECK: memref.store %{{.*}}, %[[select]] %w = tensor.insert %cst into %s[%idx] : tensor diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir --- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir +++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir @@ -62,8 +62,8 @@ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] // CHECK-SAME: outs(%[[ARG1]] // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32, %{{.*}}: i32): - // CHECK: select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32 - %0 = select %arg0, %arg1, %arg2 : tensor, tensor + // CHECK: arith.select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32 + %0 = arith.select %arg0, %arg1, %arg2 : tensor, tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -533,7 +533,7 @@ %i = linalg.index 0 : index %0 = arith.constant 0 : index %1 = arith.cmpi eq, %0, %i : index - %2 = select %1, %b, %c : f32 + %2 = arith.select %1, %b, %c : f32 %3 = arith.addf %a, %2 : f32 linalg.yield %3 : f32 } @@ -547,7 +547,7 @@ // CHECK: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]] // CHECK: %[[b:.*]] = memref.load %[[ARG1]][] // CHECK: %[[c:.*]] = memref.load %[[ARG2]][] -// CHECK: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]] +// CHECK: %[[d:.*]] = arith.select %{{.*}}, %[[b]], %[[c]] // CHECK: %[[e:.*]] = arith.addf %[[a]], %[[d]] // CHECK: store %[[e]], %[[ARG2]][] @@ -559,7 +559,7 @@ // CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]] // CHECKPARALLEL: %[[b:.*]] = memref.load %[[ARG1]][] // CHECKPARALLEL: %[[c:.*]] = memref.load %[[ARG2]][] -// CHECKPARALLEL: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]] +// CHECKPARALLEL: %[[d:.*]] = arith.select %{{.*}}, %[[b]], %[[c]] // CHECKPARALLEL: %[[e:.*]] = arith.addf %[[a]], %[[d]] // CHECKPARALLEL: store %[[e]], %[[ARG2]][] diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -253,12 +253,12 @@ outs(%1, %3 : tensor, tensor) { ^bb0(%arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32): %5 = arith.cmpi sge, %arg3, %arg5 : i32 - %6 = select %5, %arg3, %arg5 : i32 + %6 = arith.select %5, %arg3, %arg5 : i32 %7 = arith.cmpi eq, %arg3, %arg5 : i32 %8 = arith.cmpi slt, %arg4, %arg6 : i32 - %9 = select %8, %arg4, %arg6 : i32 - %10 = select %5, %arg4, %arg6 : i32 - %11 = select %7, %9, %10 : i32 + %9 = arith.select %8, %arg4, %arg6 : i32 + %10 = arith.select %5, %arg4, %arg6 : i32 + %11 = arith.select %7, %9, %10 : i32 linalg.yield %6, %11 : i32, i32 } -> (tensor, tensor) return %4#0, %4#1 : tensor, tensor diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- 
a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -328,8 +328,8 @@ %11 = arith.mulf %arg5, %8 : f32 // CHECK: %[[RSQRT:.*]] = math.rsqrt %[[V3]] : vector<4x256xf32> %12 = math.rsqrt %arg5 : f32 - // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> - %13 = select %7, %arg5, %arg6 : f32 + // CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> + %13 = arith.select %7, %arg5, %arg6 : f32 // CHECK: %[[SUB:.*]] = arith.subf %[[V3]], %[[V0]] : vector<4x256xf32> %14 = arith.subf %arg5, %arg4 : f32 // CHECK: %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32> @@ -406,8 +406,8 @@ %11 = arith.mulf %arg5, %8 : f32 // CHECK: %[[RSQRT:.*]] = math.rsqrt %[[V3]] : vector<4x256xf32> %12 = math.rsqrt %arg5 : f32 - // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> - %13 = select %7, %arg5, %arg6 : f32 + // CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> + %13 = arith.select %7, %arg5, %arg6 : f32 // CHECK: %[[SUB:.*]] = arith.subf %[[V3]], %[[V0]] : vector<4x256xf32> %14 = arith.subf %arg5, %arg4 : f32 // CHECK: %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32> diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -39,27 +39,27 @@ // CHECK-DAG: %[[val_cst_28:.*]] = arith.constant 3.750000e+00 : f32 // CHECK: %[[val_0:.*]] = arith.cmpf olt, %[[val_arg0]], %[[val_cst]] : f32 // CHECK: %[[val_1:.*]] = arith.negf %[[val_arg0]] : f32 -// CHECK: %[[val_2:.*]] = select %[[val_0]], %[[val_1]], %[[val_arg0]] : f32 +// CHECK: %[[val_2:.*]] = arith.select %[[val_0]], %[[val_1]], %[[val_arg0]] : f32 // CHECK: %[[val_3:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_26]] : f32 -// CHECK: %[[val_4:.*]] = select %[[val_3]], %[[val_cst_1]], %[[val_cst_5]] : f32 -// CHECK: %[[val_5:.*]] = select %[[val_3]], %[[val_cst_14]], %[[val_cst_18]] : f32 -// CHECK: %[[val_6:.*]] = select %[[val_3]], %[[val_cst_2]], %[[val_cst_6]] : f32 -// CHECK: %[[val_7:.*]] = select %[[val_3]], %[[val_cst_15]], %[[val_cst_19]] : f32 -// CHECK: %[[val_8:.*]] = select %[[val_3]], %[[val_cst_3]], %[[val_cst_7]] : f32 -// CHECK: %[[val_9:.*]] = select %[[val_3]], %[[val_cst_16]], %[[val_cst_20]] : f32 -// CHECK: %[[val_10:.*]] = select %[[val_3]], %[[val_cst_4]], %[[val_cst_8]] : f32 -// CHECK: %[[val_11:.*]] = select %[[val_3]], %[[val_cst_17]], %[[val_cst_21]] : f32 +// CHECK: %[[val_4:.*]] = arith.select %[[val_3]], %[[val_cst_1]], %[[val_cst_5]] : f32 +// CHECK: %[[val_5:.*]] = arith.select %[[val_3]], %[[val_cst_14]], %[[val_cst_18]] : f32 +// CHECK: %[[val_6:.*]] = arith.select %[[val_3]], %[[val_cst_2]], %[[val_cst_6]] : f32 +// CHECK: %[[val_7:.*]] = arith.select %[[val_3]], %[[val_cst_15]], %[[val_cst_19]] : f32 +// CHECK: %[[val_8:.*]] = arith.select %[[val_3]], %[[val_cst_3]], %[[val_cst_7]] : f32 +// CHECK: %[[val_9:.*]] = arith.select %[[val_3]], %[[val_cst_16]], %[[val_cst_20]] : f32 +// CHECK: %[[val_10:.*]] = arith.select %[[val_3]], %[[val_cst_4]], %[[val_cst_8]] : f32 +// CHECK: %[[val_11:.*]] = arith.select %[[val_3]], %[[val_cst_17]], %[[val_cst_21]] : f32 // CHECK: %[[val_12:.*]] = arith.cmpf olt, %[[val_2]], %[[val_cst_27]] : f32 -// CHECK: %[[val_13:.*]] = select %[[val_12]], %[[val_cst]], %[[val_cst_9]] : f32 
-// CHECK: %[[val_14:.*]] = select %[[val_12]], %[[val_4]], %[[val_cst_10]] : f32 -// CHECK: %[[val_15:.*]] = select %[[val_12]], %[[val_5]], %[[val_cst_22]] : f32 -// CHECK: %[[val_16:.*]] = select %[[val_12]], %[[val_6]], %[[val_cst_11]] : f32 -// CHECK: %[[val_17:.*]] = select %[[val_12]], %[[val_7]], %[[val_cst_23]] : f32 -// CHECK: %[[val_18:.*]] = select %[[val_12]], %[[val_8]], %[[val_cst_12]] : f32 -// CHECK: %[[val_19:.*]] = select %[[val_12]], %[[val_9]], %[[val_cst_24]] : f32 -// CHECK: %[[val_20:.*]] = select %[[val_12]], %[[val_10]], %[[val_cst_13]] : f32 -// CHECK: %[[val_21:.*]] = select %[[val_12]], %[[val_11]], %[[val_cst_25]] : f32 -// CHECK: %[[val_22:.*]] = select %[[val_12]], %[[val_cst]], %[[val_cst_0]] : f32 +// CHECK: %[[val_13:.*]] = arith.select %[[val_12]], %[[val_cst]], %[[val_cst_9]] : f32 +// CHECK: %[[val_14:.*]] = arith.select %[[val_12]], %[[val_4]], %[[val_cst_10]] : f32 +// CHECK: %[[val_15:.*]] = arith.select %[[val_12]], %[[val_5]], %[[val_cst_22]] : f32 +// CHECK: %[[val_16:.*]] = arith.select %[[val_12]], %[[val_6]], %[[val_cst_11]] : f32 +// CHECK: %[[val_17:.*]] = arith.select %[[val_12]], %[[val_7]], %[[val_cst_23]] : f32 +// CHECK: %[[val_18:.*]] = arith.select %[[val_12]], %[[val_8]], %[[val_cst_12]] : f32 +// CHECK: %[[val_19:.*]] = arith.select %[[val_12]], %[[val_9]], %[[val_cst_24]] : f32 +// CHECK: %[[val_20:.*]] = arith.select %[[val_12]], %[[val_10]], %[[val_cst_13]] : f32 +// CHECK: %[[val_21:.*]] = arith.select %[[val_12]], %[[val_11]], %[[val_cst_25]] : f32 +// CHECK: %[[val_22:.*]] = arith.select %[[val_12]], %[[val_cst]], %[[val_cst_0]] : f32 // CHECK: %[[val_23:.*]] = arith.cmpf ult, %[[val_2]], %[[val_cst_28]] : f32 // CHECK: %[[val_24:.*]] = math.fma %[[val_2]], %[[val_20]], %[[val_18]] : f32 // CHECK: %[[val_25:.*]] = math.fma %[[val_2]], %[[val_24]], %[[val_16]] : f32 @@ -71,9 +71,9 @@ // CHECK: %[[val_31:.*]] = math.fma %[[val_2]], %[[val_30]], %[[val_cst_0]] : f32 // CHECK: %[[val_32:.*]] = arith.divf %[[val_27]], %[[val_31]] : f32 // CHECK: %[[val_33:.*]] = arith.addf %[[val_22]], %[[val_32]] : f32 -// CHECK: %[[val_34:.*]] = select %[[val_23]], %[[val_33]], %[[val_cst_0]] : f32 +// CHECK: %[[val_34:.*]] = arith.select %[[val_23]], %[[val_33]], %[[val_cst_0]] : f32 // CHECK: %[[val_35:.*]] = arith.negf %[[val_34]] : f32 -// CHECK: %[[val_36:.*]] = select %[[val_0]], %[[val_35]], %[[val_34]] : f32 +// CHECK: %[[val_36:.*]] = arith.select %[[val_0]], %[[val_35]], %[[val_34]] : f32 // CHECK: return %[[val_36]] : f32 // CHECK: } func @erf_scalar(%arg0: f32) -> f32 { @@ -86,7 +86,7 @@ // CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> // CHECK-NOT: erf // CHECK-COUNT-20: select -// CHECK: %[[res:.*]] = select +// CHECK: %[[res:.*]] = arith.select // CHECK: return %[[res]] : vector<8xf32> // CHECK: } func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> { @@ -132,10 +132,10 @@ // CHECK: %[[VAL_34:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_10]] : f32 // CHECK: %[[VAL_35:.*]] = arith.cmpf ogt, %[[VAL_0]], %[[VAL_9]] : f32 // CHECK: %[[VAL_36:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1 -// CHECK: %[[VAL_37:.*]] = select %[[VAL_35]], %[[VAL_10]], %[[VAL_12]] : f32 -// CHECK: %[[VAL_38:.*]] = select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32 -// CHECK: %[[VAL_39:.*]] = select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32 -// CHECK: %[[VAL_40:.*]] = select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32 +// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_10]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_38:.*]] = 
arith.select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32 +// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32 +// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32 // CHECK: return %[[VAL_40]] : f32 func @exp_scalar(%arg0: f32) -> f32 { %0 = math.exp %arg0 : f32 @@ -147,7 +147,7 @@ // CHECK: %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32> // CHECK-NOT: exp // CHECK-COUNT-3: select -// CHECK: %[[VAL_40:.*]] = select +// CHECK: %[[VAL_40:.*]] = arith.select // CHECK: return %[[VAL_40]] : vector<8xf32> func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> { %0 = math.exp %arg0 : vector<8xf32> @@ -162,19 +162,19 @@ // CHECK: %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32 // CHECK-NOT: exp // CHECK-COUNT-3: select -// CHECK: %[[EXP_X:.*]] = select +// CHECK: %[[EXP_X:.*]] = arith.select // CHECK: %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32 // CHECK-NOT: log // CHECK-COUNT-5: select -// CHECK: %[[LOG_U:.*]] = select +// CHECK: %[[LOG_U:.*]] = arith.select // CHECK: %[[VAL_104:.*]] = arith.cmpf oeq, %[[LOG_U]], %[[EXP_X]] : f32 // CHECK: %[[VAL_105:.*]] = arith.divf %[[X]], %[[LOG_U]] : f32 // CHECK: %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32 -// CHECK: %[[VAL_107:.*]] = select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32 -// CHECK: %[[VAL_108:.*]] = select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32 -// CHECK: %[[VAL_109:.*]] = select %[[VAL_58]], %[[X]], %[[VAL_108]] : f32 +// CHECK: %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32 +// CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32 +// CHECK: %[[VAL_109:.*]] = arith.select %[[VAL_58]], %[[X]], %[[VAL_108]] : f32 // CHECK: return %[[VAL_109]] : f32 // CHECK: } func @expm1_scalar(%arg0: f32) -> f32 { @@ -191,7 +191,7 @@ // CHECK-COUNT-5: select // CHECK-NOT: expm1 // CHECK-COUNT-3: select -// CHECK: %[[VAL_115:.*]] = select +// CHECK: %[[VAL_115:.*]] = arith.select // CHECK: return %[[VAL_115]] : vector<8x8xf32> // CHECK: } func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> { @@ -224,7 +224,7 @@ // CHECK: %[[VAL_21:.*]] = arith.constant 23 : i32 // CHECK: %[[VAL_22:.*]] = arith.constant 0.693147182 : f32 // CHECK: %[[VAL_23:.*]] = arith.cmpf ogt, %[[X]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_24:.*]] = select %[[VAL_23]], %[[X]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[X]], %[[VAL_4]] : f32 // CHECK-NOT: frexp // CHECK: %[[VAL_25:.*]] = arith.bitcast %[[VAL_24]] : f32 to i32 // CHECK: %[[VAL_26:.*]] = arith.andi %[[VAL_25]], %[[VAL_19]] : i32 @@ -235,9 +235,9 @@ // CHECK: %[[VAL_31:.*]] = arith.sitofp %[[VAL_30]] : i32 to f32 // CHECK: %[[VAL_32:.*]] = arith.subf %[[VAL_31]], %[[VAL_18]] : f32 // CHECK: %[[VAL_33:.*]] = arith.cmpf olt, %[[VAL_28]], %[[VAL_8]] : f32 -// CHECK: %[[VAL_34:.*]] = select %[[VAL_33]], %[[VAL_28]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_34:.*]] = arith.select %[[VAL_33]], %[[VAL_28]], %[[VAL_1]] : f32 // CHECK: %[[VAL_35:.*]] = arith.subf %[[VAL_28]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_36:.*]] = select %[[VAL_33]], %[[VAL_2]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_36:.*]] = arith.select %[[VAL_33]], %[[VAL_2]], %[[VAL_1]] : f32 // CHECK: %[[VAL_37:.*]] = arith.subf %[[VAL_32]], %[[VAL_36]] : f32 // CHECK: %[[VAL_38:.*]] = arith.addf 
%[[VAL_35]], %[[VAL_34]] : f32 // CHECK: %[[VAL_39:.*]] = arith.mulf %[[VAL_38]], %[[VAL_38]] : f32 @@ -257,9 +257,9 @@ // CHECK: %[[VAL_53:.*]] = arith.cmpf ult, %[[X]], %[[VAL_1]] : f32 // CHECK: %[[VAL_54:.*]] = arith.cmpf oeq, %[[X]], %[[VAL_1]] : f32 // CHECK: %[[VAL_55:.*]] = arith.cmpf oeq, %[[X]], %[[VAL_6]] : f32 -// CHECK: %[[VAL_56:.*]] = select %[[VAL_55]], %[[VAL_6]], %[[VAL_52]] : f32 -// CHECK: %[[VAL_57:.*]] = select %[[VAL_53]], %[[VAL_7]], %[[VAL_56]] : f32 -// CHECK: %[[VAL_58:.*]] = select %[[VAL_54]], %[[VAL_5]], %[[VAL_57]] : f32 +// CHECK: %[[VAL_56:.*]] = arith.select %[[VAL_55]], %[[VAL_6]], %[[VAL_52]] : f32 +// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_53]], %[[VAL_7]], %[[VAL_56]] : f32 +// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_54]], %[[VAL_5]], %[[VAL_57]] : f32 // CHECK: return %[[VAL_58]] : f32 // CHECK: } func @log_scalar(%arg0: f32) -> f32 { @@ -271,7 +271,7 @@ // CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { // CHECK: %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<8xf32> // CHECK-COUNT-5: select -// CHECK: %[[VAL_71:.*]] = select +// CHECK: %[[VAL_71:.*]] = arith.select // CHECK: return %[[VAL_71]] : vector<8xf32> // CHECK: } func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> { @@ -283,7 +283,7 @@ // CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { // CHECK: %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32 // CHECK-COUNT-5: select -// CHECK: %[[VAL_65:.*]] = select +// CHECK: %[[VAL_65:.*]] = arith.select // CHECK: return %[[VAL_65]] : f32 // CHECK: } func @log2_scalar(%arg0: f32) -> f32 { @@ -295,7 +295,7 @@ // CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { // CHECK: %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<8xf32> // CHECK-COUNT-5: select -// CHECK: %[[VAL_71:.*]] = select +// CHECK: %[[VAL_71:.*]] = arith.select // CHECK: return %[[VAL_71]] : vector<8xf32> // CHECK: } func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> { @@ -310,13 +310,13 @@ // CHECK: %[[U_SMALL:.*]] = arith.cmpf oeq, %[[U]], %[[CST_ONE]] : f32 // CHECK-NOT: log // CHECK-COUNT-5: select -// CHECK: %[[LOG_U:.*]] = select +// CHECK: %[[LOG_U:.*]] = arith.select // CHECK: %[[U_INF:.*]] = arith.cmpf oeq, %[[U]], %[[LOG_U]] : f32 // CHECK: %[[VAL_69:.*]] = arith.subf %[[U]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_70:.*]] = arith.divf %[[LOG_U]], %[[VAL_69]] : f32 // CHECK: %[[LOG_LARGE:.*]] = arith.mulf %[[X]], %[[VAL_70]] : f32 // CHECK: %[[VAL_72:.*]] = arith.ori %[[U_SMALL]], %[[U_INF]] : i1 -// CHECK: %[[APPROX:.*]] = select %[[VAL_72]], %[[X]], %[[LOG_LARGE]] : f32 +// CHECK: %[[APPROX:.*]] = arith.select %[[VAL_72]], %[[X]], %[[LOG_LARGE]] : f32 // CHECK: return %[[APPROX]] : f32 // CHECK: } func @log1p_scalar(%arg0: f32) -> f32 { @@ -328,7 +328,7 @@ // CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { // CHECK: %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32> // CHECK-COUNT-6: select -// CHECK: %[[VAL_79:.*]] = select +// CHECK: %[[VAL_79:.*]] = arith.select // CHECK: return %[[VAL_79]] : vector<8xf32> // CHECK: } func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> { @@ -354,9 +354,9 @@ // CHECK: %[[VAL_13:.*]] = arith.constant 1.18534706E-4 : f32 // CHECK: %[[VAL_14:.*]] = arith.constant 1.19825836E-6 : f32 // CHECK: %[[VAL_15:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_16:.*]] = select %[[VAL_15]], %[[VAL_0]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_15]], %[[VAL_0]], %[[VAL_2]] : f32 // CHECK: %[[VAL_17:.*]] = arith.cmpf 
ogt, %[[VAL_16]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_18:.*]] = select %[[VAL_17]], %[[VAL_16]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_16]], %[[VAL_1]] : f32 // CHECK: %[[VAL_19:.*]] = math.abs %[[VAL_0]] : f32 // CHECK: %[[VAL_20:.*]] = arith.cmpf olt, %[[VAL_19]], %[[VAL_3]] : f32 // CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_18]], %[[VAL_18]] : f32 @@ -371,7 +371,7 @@ // CHECK: %[[VAL_30:.*]] = math.fma %[[VAL_21]], %[[VAL_29]], %[[VAL_12]] : f32 // CHECK: %[[VAL_31:.*]] = math.fma %[[VAL_21]], %[[VAL_30]], %[[VAL_11]] : f32 // CHECK: %[[VAL_32:.*]] = arith.divf %[[VAL_28]], %[[VAL_31]] : f32 -// CHECK: %[[VAL_33:.*]] = select %[[VAL_20]], %[[VAL_18]], %[[VAL_32]] : f32 +// CHECK: %[[VAL_33:.*]] = arith.select %[[VAL_20]], %[[VAL_18]], %[[VAL_32]] : f32 // CHECK: return %[[VAL_33]] : f32 // CHECK: } func @tanh_scalar(%arg0: f32) -> f32 { @@ -384,7 +384,7 @@ // CHECK: %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<8xf32> // CHECK-NOT: tanh // CHECK-COUNT-2: select -// CHECK: %[[VAL_33:.*]] = select +// CHECK: %[[VAL_33:.*]] = arith.select // CHECK: return %[[VAL_33]] : vector<8xf32> // CHECK: } func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> { @@ -418,7 +418,7 @@ // AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32> // AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32> // AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32> -// AVX2: %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32> +// AVX2: %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32> // AVX2: return %[[VAL_13]] : vector<8xf32> // AVX2: } func @rsqrt_vector_8xf32(%arg0: vector<8xf32>) -> vector<8xf32> { @@ -518,13 +518,13 @@ // CHECK-DAG: %[[ABS:.+]] = math.abs %arg0 // CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]] // CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]] -// CHECK-DAG: %[[SEL:.+]] = select %[[CMP]], %[[ABS]], %[[DIV]] +// CHECK-DAG: %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]] // CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]] // CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]] // CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]] // CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]] // CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]] -// CHECK-DAG: %[[EST:.+]] = select %[[CMP]], %[[P3]], %[[SUB]] +// CHECK-DAG: %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]] // CHECK-DAG: %[[RES:.+]] = math.copysign %[[EST]], %arg0 // CHECK: return %[[RES]] func @atan_scalar(%arg0: f32) -> f32 { @@ -546,13 +546,13 @@ // CHECK-DAG: %[[ABS:.+]] = math.abs %[[RATIO]] // CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]] // CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]] -// CHECK-DAG: %[[SEL:.+]] = select %[[CMP]], %[[ABS]], %[[DIV]] +// CHECK-DAG: %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]] // CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]] // CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]] // CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]] // CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]] // CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]] -// CHECK-DAG: %[[EST:.+]] = select %[[CMP]], %[[P3]], %[[SUB]] +// CHECK-DAG: %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]] // CHECK-DAG: %[[ATAN:.+]] = math.copysign %[[EST]], %[[RATIO]] // Handle the case of x < 0: @@ 
-561,27 +561,27 @@ // CHECK-DAG: %[[ADD_PI:.+]] = arith.addf %[[ATAN]], %[[PI]] // CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]] // CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]] -// CHECK-DAG: %[[ATAN_ADJUST:.+]] = select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]] +// CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]] // CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]] -// CHECK-DAG: %[[ATAN_EST:.+]] = select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]] +// CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]] // Handle PI / 2 edge case: // CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]] // CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]] // CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]] -// CHECK-DAG: %[[EDGE1:.+]] = select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]] +// CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]] // Handle -PI / 2 edge case: // CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637 // CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] // CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]] -// CHECK-DAG: %[[EDGE2:.+]] = select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]] +// CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]] // Handle Nan edgecase: // CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]] // CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]] // CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000 -// CHECK-DAG: %[[EDGE3:.+]] = select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]] +// CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]] // CHECK: return %[[EDGE3]] func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 { diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir --- a/mlir/test/Dialect/MemRef/expand-ops.mlir +++ b/mlir/test/Dialect/MemRef/expand-ops.mlir @@ -9,7 +9,7 @@ // CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { // CHECK: ^bb0([[CUR_VAL:%.*]]: f32): // CHECK: [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32 -// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32 +// CHECK: [[SELECT:%.*]] = arith.select [[CMP]], [[CUR_VAL]], [[f]] : f32 // CHECK: memref.atomic_yield [[SELECT]] : f32 // CHECK: } // CHECK: return %0 : f32 diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -280,7 +280,7 @@ // CHECK-LABEL: func @to_select1 // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index -// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]] +// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]] // CHECK: return [[V0]] : index // ----- @@ -299,7 +299,7 @@ // CHECK-LABEL: func @to_select_same_val // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index -// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]] +// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]] // CHECK: return [[V0]], [[C1]] : index, index // ----- @@ -322,8 +322,8 @@ // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index // CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index // CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index -// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C2]] -// 
CHECK: [[V1:%.*]] = select {{.*}}, [[C1]], [[C3]] +// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C2]] +// CHECK: [[V1:%.*]] = arith.select {{.*}}, [[C1]], [[C3]] // CHECK: return [[V0]], [[V1]] : index // ----- diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -8,9 +8,9 @@ scf.for %i0 = %arg0 to %arg1 step %arg2 { scf.for %i1 = %arg0 to %arg1 step %arg2 { %min_cmp = arith.cmpi slt, %i0, %i1 : index - %min = select %min_cmp, %i0, %i1 : index + %min = arith.select %min_cmp, %i0, %i1 : index %max_cmp = arith.cmpi sge, %i0, %i1 : index - %max = select %max_cmp, %i0, %i1 : index + %max = arith.select %max_cmp, %i0, %i1 : index scf.for %i2 = %min to %max step %i1 { } } @@ -21,9 +21,9 @@ // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : index -// CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : index +// CHECK-NEXT: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index // CHECK-NEXT: %{{.*}} = arith.cmpi sge, %{{.*}}, %{{.*}} : index -// CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : index +// CHECK-NEXT: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { func @std_if(%arg0: i1, %arg1: f32) { @@ -56,9 +56,9 @@ scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) { %min_cmp = arith.cmpi slt, %i0, %i1 : index - %min = select %min_cmp, %i0, %i1 : index + %min = arith.select %min_cmp, %i0, %i1 : index %max_cmp = arith.cmpi sge, %i0, %i1 : index - %max = select %max_cmp, %i0, %i1 : index + %max = arith.select %max_cmp, %i0, %i1 : index %zero = arith.constant 0.0 : f32 %int_zero = arith.constant 0 : i32 %red:2 = scf.parallel (%i2) = (%min) to (%max) step (%i1) @@ -89,9 +89,9 @@ // CHECK-NEXT: scf.parallel (%[[I0:.*]], %[[I1:.*]]) = (%[[ARG0]], %[[ARG1]]) to // CHECK: (%[[ARG2]], %[[ARG3]]) step (%[[ARG4]], %[[STEP]]) { // CHECK-NEXT: %[[MIN_CMP:.*]] = arith.cmpi slt, %[[I0]], %[[I1]] : index -// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_CMP]], %[[I0]], %[[I1]] : index +// CHECK-NEXT: %[[MIN:.*]] = arith.select %[[MIN_CMP]], %[[I0]], %[[I1]] : index // CHECK-NEXT: %[[MAX_CMP:.*]] = arith.cmpi sge, %[[I0]], %[[I1]] : index -// CHECK-NEXT: %[[MAX:.*]] = select %[[MAX_CMP]], %[[I0]], %[[I1]] : index +// CHECK-NEXT: %[[MAX:.*]] = arith.select %[[MAX_CMP]], %[[I0]], %[[I1]] : index // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-NEXT: %[[INT_ZERO:.*]] = arith.constant 0 : i32 // CHECK-NEXT: scf.parallel (%{{.*}}) = (%[[MIN]]) to (%[[MAX]]) diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir @@ -138,7 +138,7 @@ // CHECK: } // CHECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index // CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_18]], %[[VAL_6]] : index -// CHECK: %[[VAL_26:.*]] = select %[[VAL_24]], %[[VAL_25]], %[[VAL_18]] : index +// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_24]], %[[VAL_25]], %[[VAL_18]] : index // CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_26]], %[[VAL_27]] : index, index // CHECK: } @@ -345,7 +345,7 @@ // CHECK: } // CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_21]], 
%[[VAL_20]] : index // CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index -// CHECK: %[[VAL_29:.*]] = select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index +// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index // CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_20]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_29]], %[[VAL_30]] : index, index // CHECK: } @@ -441,7 +441,7 @@ // CHECK: } // CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index // CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index -// CHECK: %[[VAL_29:.*]] = select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index +// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index // CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_20]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_29]], %[[VAL_30]] : index, index // CHECK: } @@ -528,7 +528,7 @@ // CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_23]]] : memref // CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref // CHECK: %[[VAL_27:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_25]] : index -// CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_28:.*]] = arith.select %[[VAL_27]], %[[VAL_26]], %[[VAL_25]] : index // CHECK: %[[VAL_29:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_28]] : index // CHECK: %[[VAL_30:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_28]] : index // CHECK: %[[VAL_31:.*]] = arith.andi %[[VAL_29]], %[[VAL_30]] : i1 @@ -553,10 +553,10 @@ // CHECK: } // CHECK: %[[VAL_39:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_28]] : index // CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_23]], %[[VAL_4]] : index -// CHECK: %[[VAL_41:.*]] = select %[[VAL_39]], %[[VAL_40]], %[[VAL_23]] : index +// CHECK: %[[VAL_41:.*]] = arith.select %[[VAL_39]], %[[VAL_40]], %[[VAL_23]] : index // CHECK: %[[VAL_42:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_28]] : index // CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_24]], %[[VAL_4]] : index -// CHECK: %[[VAL_44:.*]] = select %[[VAL_42]], %[[VAL_43]], %[[VAL_24]] : index +// CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_42]], %[[VAL_43]], %[[VAL_24]] : index // CHECK: scf.yield %[[VAL_41]], %[[VAL_44]] : index, index // CHECK: } // CHECK: scf.for %[[VAL_45:.*]] = %[[VAL_46:.*]]#0 to %[[VAL_14]] step %[[VAL_4]] { @@ -612,7 +612,7 @@ // CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_23]]] : memref // CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref // CHECK: %[[VAL_27:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_25]] : index -// CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_28:.*]] = arith.select %[[VAL_27]], %[[VAL_26]], %[[VAL_25]] : index // CHECK: %[[VAL_29:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_28]] : index // CHECK: %[[VAL_30:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_28]] : index // CHECK: %[[VAL_31:.*]] = arith.andi %[[VAL_29]], %[[VAL_30]] : i1 @@ -625,10 +625,10 @@ // CHECK: } // CHECK: %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_28]] : index // CHECK: %[[VAL_36:.*]] = arith.addi %[[VAL_23]], %[[VAL_4]] : index -// CHECK: %[[VAL_37:.*]] = select %[[VAL_35]], %[[VAL_36]], %[[VAL_23]] : index +// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_23]] : index // CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_28]] : index // CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_24]], %[[VAL_4]] : index -// CHECK: %[[VAL_40:.*]] = select %[[VAL_38]], %[[VAL_39]], %[[VAL_24]] : index +// 
CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_24]] : index // CHECK: scf.yield %[[VAL_37]], %[[VAL_40]] : index, index // CHECK: } // CHECK: %[[VAL_41:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<32xf32> @@ -675,7 +675,7 @@ // CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref // CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref // CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_27]], %[[VAL_26]] : index -// CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_28]], %[[VAL_27]], %[[VAL_26]] : index // CHECK: %[[VAL_30:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_29]] : index // CHECK: %[[VAL_31:.*]] = arith.cmpi eq, %[[VAL_27]], %[[VAL_29]] : index // CHECK: %[[VAL_32:.*]] = arith.andi %[[VAL_30]], %[[VAL_31]] : i1 @@ -704,10 +704,10 @@ // CHECK: } // CHECK: %[[VAL_44:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_29]] : index // CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_24]], %[[VAL_5]] : index -// CHECK: %[[VAL_46:.*]] = select %[[VAL_44]], %[[VAL_45]], %[[VAL_24]] : index +// CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_44]], %[[VAL_45]], %[[VAL_24]] : index // CHECK: %[[VAL_47:.*]] = arith.cmpi eq, %[[VAL_27]], %[[VAL_29]] : index // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_25]], %[[VAL_5]] : index -// CHECK: %[[VAL_49:.*]] = select %[[VAL_47]], %[[VAL_48]], %[[VAL_25]] : index +// CHECK: %[[VAL_49:.*]] = arith.select %[[VAL_47]], %[[VAL_48]], %[[VAL_25]] : index // CHECK: scf.yield %[[VAL_46]], %[[VAL_49]] : index, index // CHECK: } // CHECK: scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#0 to %[[VAL_15]] step %[[VAL_5]] { @@ -769,7 +769,7 @@ // CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref // CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref // CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_27]], %[[VAL_26]] : index -// CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_28]], %[[VAL_27]], %[[VAL_26]] : index // CHECK: %[[VAL_30:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_29]] : index // CHECK: %[[VAL_31:.*]] = arith.cmpi eq, %[[VAL_27]], %[[VAL_29]] : index // CHECK: %[[VAL_32:.*]] = arith.andi %[[VAL_30]], %[[VAL_31]] : i1 @@ -797,10 +797,10 @@ // CHECK: } // CHECK: %[[VAL_43:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_29]] : index // CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_24]], %[[VAL_5]] : index -// CHECK: %[[VAL_45:.*]] = select %[[VAL_43]], %[[VAL_44]], %[[VAL_24]] : index +// CHECK: %[[VAL_45:.*]] = arith.select %[[VAL_43]], %[[VAL_44]], %[[VAL_24]] : index // CHECK: %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_27]], %[[VAL_29]] : index // CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_25]], %[[VAL_5]] : index -// CHECK: %[[VAL_48:.*]] = select %[[VAL_46]], %[[VAL_47]], %[[VAL_25]] : index +// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_46]], %[[VAL_47]], %[[VAL_25]] : index // CHECK: scf.yield %[[VAL_45]], %[[VAL_48]] : index, index // CHECK: } // CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#0 to %[[VAL_15]] step %[[VAL_5]] { @@ -914,7 +914,7 @@ // CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_25]]] : memref // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref // CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index -// CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : 
index // CHECK: %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index // CHECK: %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index // CHECK: %[[VAL_34:.*]] = arith.andi %[[VAL_32]], %[[VAL_33]] : i1 @@ -945,10 +945,10 @@ // CHECK: } // CHECK: %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index // CHECK: %[[VAL_51:.*]] = arith.addi %[[VAL_25]], %[[VAL_4]] : index -// CHECK: %[[VAL_52:.*]] = select %[[VAL_50]], %[[VAL_51]], %[[VAL_25]] : index +// CHECK: %[[VAL_52:.*]] = arith.select %[[VAL_50]], %[[VAL_51]], %[[VAL_25]] : index // CHECK: %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index // CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_26]], %[[VAL_4]] : index -// CHECK: %[[VAL_55:.*]] = select %[[VAL_53]], %[[VAL_54]], %[[VAL_26]] : index +// CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_53]], %[[VAL_54]], %[[VAL_26]] : index // CHECK: scf.yield %[[VAL_52]], %[[VAL_55]], %[[VAL_56:.*]] : index, index, f32 // CHECK: } // CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_59:.*]]#0 to %[[VAL_15]] step %[[VAL_4]] iter_args(%[[VAL_60:.*]] = %[[VAL_59]]#2) -> (f32) { @@ -1025,7 +1025,7 @@ // CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_28]]] : memref // CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref // CHECK: %[[VAL_33:.*]] = arith.cmpi ult, %[[VAL_32]], %[[VAL_31]] : index -// CHECK: %[[VAL_34:.*]] = select %[[VAL_33]], %[[VAL_32]], %[[VAL_31]] : index +// CHECK: %[[VAL_34:.*]] = arith.select %[[VAL_33]], %[[VAL_32]], %[[VAL_31]] : index // CHECK: %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_31]], %[[VAL_34]] : index // CHECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_34]] : index // CHECK: %[[VAL_37:.*]] = arith.andi %[[VAL_35]], %[[VAL_36]] : i1 @@ -1058,10 +1058,10 @@ // CHECK: } // CHECK: %[[VAL_55:.*]] = arith.cmpi eq, %[[VAL_31]], %[[VAL_34]] : index // CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_28]], %[[VAL_5]] : index -// CHECK: %[[VAL_57:.*]] = select %[[VAL_55]], %[[VAL_56]], %[[VAL_28]] : index +// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_55]], %[[VAL_56]], %[[VAL_28]] : index // CHECK: %[[VAL_58:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_34]] : index // CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_29]], %[[VAL_5]] : index -// CHECK: %[[VAL_60:.*]] = select %[[VAL_58]], %[[VAL_59]], %[[VAL_29]] : index +// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_58]], %[[VAL_59]], %[[VAL_29]] : index // CHECK: scf.yield %[[VAL_57]], %[[VAL_60]], %[[VAL_61:.*]] : index, index, f32 // CHECK: } // CHECK: %[[VAL_62:.*]] = scf.for %[[VAL_63:.*]] = %[[VAL_64:.*]]#0 to %[[VAL_18]] step %[[VAL_5]] iter_args(%[[VAL_65:.*]] = %[[VAL_64]]#2) -> (f32) { @@ -1186,10 +1186,10 @@ // CHECK: } // CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_32]] : index // CHECK: %[[VAL_61:.*]] = arith.addi %[[VAL_30]], %[[VAL_7]] : index -// CHECK: %[[VAL_62:.*]] = select %[[VAL_60]], %[[VAL_61]], %[[VAL_30]] : index +// CHECK: %[[VAL_62:.*]] = arith.select %[[VAL_60]], %[[VAL_61]], %[[VAL_30]] : index // CHECK: %[[VAL_63:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_32]] : index // CHECK: %[[VAL_64:.*]] = arith.addi %[[VAL_31]], %[[VAL_7]] : index -// CHECK: %[[VAL_65:.*]] = select %[[VAL_63]], %[[VAL_64]], %[[VAL_31]] : index +// CHECK: %[[VAL_65:.*]] = arith.select %[[VAL_63]], %[[VAL_64]], %[[VAL_31]] : index // CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_62]], %[[VAL_65]], %[[VAL_66]] : index, index, index // CHECK: } @@ -1218,7 +1218,7 @@ // CHECK: } // CHECK: 
%[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_74]], %[[VAL_73]] : index // CHECK: %[[VAL_85:.*]] = arith.addi %[[VAL_72]], %[[VAL_7]] : index -// CHECK: %[[VAL_86:.*]] = select %[[VAL_84]], %[[VAL_85]], %[[VAL_72]] : index +// CHECK: %[[VAL_86:.*]] = arith.select %[[VAL_84]], %[[VAL_85]], %[[VAL_72]] : index // CHECK: %[[VAL_87:.*]] = arith.addi %[[VAL_73]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_86]], %[[VAL_87]] : index, index // CHECK: } @@ -1247,7 +1247,7 @@ // CHECK: } // CHECK: %[[VAL_106:.*]] = arith.cmpi eq, %[[VAL_96]], %[[VAL_95]] : index // CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_94]], %[[VAL_7]] : index -// CHECK: %[[VAL_108:.*]] = select %[[VAL_106]], %[[VAL_107]], %[[VAL_94]] : index +// CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_106]], %[[VAL_107]], %[[VAL_94]] : index // CHECK: %[[VAL_109:.*]] = arith.addi %[[VAL_95]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_108]], %[[VAL_109]] : index, index // CHECK: } @@ -1326,10 +1326,10 @@ // CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref // CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref // CHECK: %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_39]], %[[VAL_38]] : index -// CHECK: %[[VAL_41:.*]] = select %[[VAL_40]], %[[VAL_39]], %[[VAL_38]] : index +// CHECK: %[[VAL_41:.*]] = arith.select %[[VAL_40]], %[[VAL_39]], %[[VAL_38]] : index // CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref // CHECK: %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_42]], %[[VAL_41]] : index -// CHECK: %[[VAL_44:.*]] = select %[[VAL_43]], %[[VAL_42]], %[[VAL_41]] : index +// CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[VAL_42]], %[[VAL_41]] : index // CHECK: %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index // CHECK: %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_44]] : index // CHECK: %[[VAL_47:.*]] = arith.andi %[[VAL_45]], %[[VAL_46]] : i1 @@ -1408,13 +1408,13 @@ // CHECK: } // CHECK: %[[VAL_99:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index // CHECK: %[[VAL_100:.*]] = arith.addi %[[VAL_34]], %[[VAL_5]] : index -// CHECK: %[[VAL_101:.*]] = select %[[VAL_99]], %[[VAL_100]], %[[VAL_34]] : index +// CHECK: %[[VAL_101:.*]] = arith.select %[[VAL_99]], %[[VAL_100]], %[[VAL_34]] : index // CHECK: %[[VAL_102:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_44]] : index // CHECK: %[[VAL_103:.*]] = arith.addi %[[VAL_35]], %[[VAL_5]] : index -// CHECK: %[[VAL_104:.*]] = select %[[VAL_102]], %[[VAL_103]], %[[VAL_35]] : index +// CHECK: %[[VAL_104:.*]] = arith.select %[[VAL_102]], %[[VAL_103]], %[[VAL_35]] : index // CHECK: %[[VAL_105:.*]] = arith.cmpi eq, %[[VAL_42]], %[[VAL_44]] : index // CHECK: %[[VAL_106:.*]] = arith.addi %[[VAL_36]], %[[VAL_5]] : index -// CHECK: %[[VAL_107:.*]] = select %[[VAL_105]], %[[VAL_106]], %[[VAL_36]] : index +// CHECK: %[[VAL_107:.*]] = arith.select %[[VAL_105]], %[[VAL_106]], %[[VAL_36]] : index // CHECK: scf.yield %[[VAL_101]], %[[VAL_104]], %[[VAL_107]], %[[VAL_108:.*]] : index, index, index, f64 // CHECK: } // CHECK: %[[VAL_109:.*]]:3 = scf.while (%[[VAL_110:.*]] = %[[VAL_111:.*]]#1, %[[VAL_112:.*]] = %[[VAL_111]]#2, %[[VAL_113:.*]] = %[[VAL_111]]#3) : (index, index, f64) -> (index, index, f64) { @@ -1427,7 +1427,7 @@ // CHECK: %[[VAL_120:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_117]]] : memref // CHECK: %[[VAL_121:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_118]]] : memref // CHECK: %[[VAL_122:.*]] = arith.cmpi ult, %[[VAL_121]], %[[VAL_120]] : index -// CHECK: %[[VAL_123:.*]] = select %[[VAL_122]], %[[VAL_121]], %[[VAL_120]] : 
index +// CHECK: %[[VAL_123:.*]] = arith.select %[[VAL_122]], %[[VAL_121]], %[[VAL_120]] : index // CHECK: %[[VAL_124:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_123]] : index // CHECK: %[[VAL_125:.*]] = arith.cmpi eq, %[[VAL_121]], %[[VAL_123]] : index // CHECK: %[[VAL_126:.*]] = arith.andi %[[VAL_124]], %[[VAL_125]] : i1 @@ -1458,10 +1458,10 @@ // CHECK: } // CHECK: %[[VAL_142:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_123]] : index // CHECK: %[[VAL_143:.*]] = arith.addi %[[VAL_117]], %[[VAL_5]] : index -// CHECK: %[[VAL_144:.*]] = select %[[VAL_142]], %[[VAL_143]], %[[VAL_117]] : index +// CHECK: %[[VAL_144:.*]] = arith.select %[[VAL_142]], %[[VAL_143]], %[[VAL_117]] : index // CHECK: %[[VAL_145:.*]] = arith.cmpi eq, %[[VAL_121]], %[[VAL_123]] : index // CHECK: %[[VAL_146:.*]] = arith.addi %[[VAL_118]], %[[VAL_5]] : index -// CHECK: %[[VAL_147:.*]] = select %[[VAL_145]], %[[VAL_146]], %[[VAL_118]] : index +// CHECK: %[[VAL_147:.*]] = arith.select %[[VAL_145]], %[[VAL_146]], %[[VAL_118]] : index // CHECK: scf.yield %[[VAL_144]], %[[VAL_147]], %[[VAL_148:.*]] : index, index, f64 // CHECK: } // CHECK: %[[VAL_149:.*]]:3 = scf.while (%[[VAL_150:.*]] = %[[VAL_151:.*]]#0, %[[VAL_152:.*]] = %[[VAL_153:.*]]#1, %[[VAL_154:.*]] = %[[VAL_153]]#2) : (index, index, f64) -> (index, index, f64) { @@ -1474,7 +1474,7 @@ // CHECK: %[[VAL_161:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_158]]] : memref // CHECK: %[[VAL_162:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_159]]] : memref // CHECK: %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_162]], %[[VAL_161]] : index -// CHECK: %[[VAL_164:.*]] = select %[[VAL_163]], %[[VAL_162]], %[[VAL_161]] : index +// CHECK: %[[VAL_164:.*]] = arith.select %[[VAL_163]], %[[VAL_162]], %[[VAL_161]] : index // CHECK: %[[VAL_165:.*]] = arith.cmpi eq, %[[VAL_161]], %[[VAL_164]] : index // CHECK: %[[VAL_166:.*]] = arith.cmpi eq, %[[VAL_162]], %[[VAL_164]] : index // CHECK: %[[VAL_167:.*]] = arith.andi %[[VAL_165]], %[[VAL_166]] : i1 @@ -1505,10 +1505,10 @@ // CHECK: } // CHECK: %[[VAL_183:.*]] = arith.cmpi eq, %[[VAL_161]], %[[VAL_164]] : index // CHECK: %[[VAL_184:.*]] = arith.addi %[[VAL_158]], %[[VAL_5]] : index -// CHECK: %[[VAL_185:.*]] = select %[[VAL_183]], %[[VAL_184]], %[[VAL_158]] : index +// CHECK: %[[VAL_185:.*]] = arith.select %[[VAL_183]], %[[VAL_184]], %[[VAL_158]] : index // CHECK: %[[VAL_186:.*]] = arith.cmpi eq, %[[VAL_162]], %[[VAL_164]] : index // CHECK: %[[VAL_187:.*]] = arith.addi %[[VAL_159]], %[[VAL_5]] : index -// CHECK: %[[VAL_188:.*]] = select %[[VAL_186]], %[[VAL_187]], %[[VAL_159]] : index +// CHECK: %[[VAL_188:.*]] = arith.select %[[VAL_186]], %[[VAL_187]], %[[VAL_159]] : index // CHECK: scf.yield %[[VAL_185]], %[[VAL_188]], %[[VAL_189:.*]] : index, index, f64 // CHECK: } // CHECK: %[[VAL_190:.*]] = scf.for %[[VAL_191:.*]] = %[[VAL_192:.*]]#1 to %[[VAL_23]] step %[[VAL_5]] iter_args(%[[VAL_193:.*]] = %[[VAL_192]]#2) -> (f64) { @@ -1526,7 +1526,7 @@ // CHECK: %[[VAL_209:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_206]]] : memref // CHECK: %[[VAL_210:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_207]]] : memref // CHECK: %[[VAL_211:.*]] = arith.cmpi ult, %[[VAL_210]], %[[VAL_209]] : index -// CHECK: %[[VAL_212:.*]] = select %[[VAL_211]], %[[VAL_210]], %[[VAL_209]] : index +// CHECK: %[[VAL_212:.*]] = arith.select %[[VAL_211]], %[[VAL_210]], %[[VAL_209]] : index // CHECK: %[[VAL_213:.*]] = arith.cmpi eq, %[[VAL_209]], %[[VAL_212]] : index // CHECK: %[[VAL_214:.*]] = arith.cmpi eq, %[[VAL_210]], %[[VAL_212]] : index // CHECK: %[[VAL_215:.*]] = arith.andi %[[VAL_213]], 
%[[VAL_214]] : i1 @@ -1557,10 +1557,10 @@ // CHECK: } // CHECK: %[[VAL_231:.*]] = arith.cmpi eq, %[[VAL_209]], %[[VAL_212]] : index // CHECK: %[[VAL_232:.*]] = arith.addi %[[VAL_206]], %[[VAL_5]] : index -// CHECK: %[[VAL_233:.*]] = select %[[VAL_231]], %[[VAL_232]], %[[VAL_206]] : index +// CHECK: %[[VAL_233:.*]] = arith.select %[[VAL_231]], %[[VAL_232]], %[[VAL_206]] : index // CHECK: %[[VAL_234:.*]] = arith.cmpi eq, %[[VAL_210]], %[[VAL_212]] : index // CHECK: %[[VAL_235:.*]] = arith.addi %[[VAL_207]], %[[VAL_5]] : index -// CHECK: %[[VAL_236:.*]] = select %[[VAL_234]], %[[VAL_235]], %[[VAL_207]] : index +// CHECK: %[[VAL_236:.*]] = arith.select %[[VAL_234]], %[[VAL_235]], %[[VAL_207]] : index // CHECK: scf.yield %[[VAL_233]], %[[VAL_236]], %[[VAL_237:.*]] : index, index, f64 // CHECK: } // CHECK: %[[VAL_238:.*]] = scf.for %[[VAL_239:.*]] = %[[VAL_240:.*]]#1 to %[[VAL_21]] step %[[VAL_5]] iter_args(%[[VAL_241:.*]] = %[[VAL_240]]#2) -> (f64) { diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir @@ -131,7 +131,7 @@ // CHECK: } // CHECK: %[[VAL_30:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_23]] : index // CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_22]], %[[VAL_7]] : index -// CHECK: %[[VAL_32:.*]] = select %[[VAL_30]], %[[VAL_31]], %[[VAL_22]] : index +// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_30]], %[[VAL_31]], %[[VAL_22]] : index // CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_23]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_32]], %[[VAL_33]] : index, index // CHECK: } @@ -239,7 +239,7 @@ // CHECK: } // CHECK: %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_22]], %[[VAL_21]] : index // CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_20]], %[[VAL_7]] : index -// CHECK: %[[VAL_34:.*]] = select %[[VAL_32]], %[[VAL_33]], %[[VAL_20]] : index +// CHECK: %[[VAL_34:.*]] = arith.select %[[VAL_32]], %[[VAL_33]], %[[VAL_20]] : index // CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_21]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_34]], %[[VAL_35]] : index, index // CHECK: } @@ -356,7 +356,7 @@ // CHECK: } // CHECK: %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_34]] : index // CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index -// CHECK: %[[VAL_43:.*]] = select %[[VAL_41]], %[[VAL_42]], %[[VAL_33]] : index +// CHECK: %[[VAL_43:.*]] = arith.select %[[VAL_41]], %[[VAL_42]], %[[VAL_33]] : index // CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_34]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_43]], %[[VAL_44]] : index, index // CHECK: } @@ -375,7 +375,7 @@ // CHECK: } // CHECK: %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_23]] : index // CHECK: %[[VAL_51:.*]] = arith.addi %[[VAL_22]], %[[VAL_7]] : index -// CHECK: %[[VAL_52:.*]] = select %[[VAL_50]], %[[VAL_51]], %[[VAL_22]] : index +// CHECK: %[[VAL_52:.*]] = arith.select %[[VAL_50]], %[[VAL_51]], %[[VAL_22]] : index // CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_23]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_52]], %[[VAL_53]] : index, index // CHECK: } @@ -476,7 +476,7 @@ // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_27]]] : memref // CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref // CHECK: %[[VAL_31:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_29]] : index -// CHECK: %[[VAL_32:.*]] = select %[[VAL_31]], %[[VAL_30]], %[[VAL_29]] : index +// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_31]], %[[VAL_30]], %[[VAL_29]] : index // CHECK: 
%[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_32]] : index // CHECK: %[[VAL_34:.*]] = arith.cmpi eq, %[[VAL_30]], %[[VAL_32]] : index // CHECK: %[[VAL_35:.*]] = arith.andi %[[VAL_33]], %[[VAL_34]] : i1 @@ -497,7 +497,7 @@ // CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_48]]] : memref // CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_49]]] : memref // CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_50]] : index -// CHECK: %[[VAL_53:.*]] = select %[[VAL_52]], %[[VAL_51]], %[[VAL_50]] : index +// CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[VAL_51]], %[[VAL_50]] : index // CHECK: %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_53]] : index // CHECK: %[[VAL_55:.*]] = arith.cmpi eq, %[[VAL_51]], %[[VAL_53]] : index // CHECK: %[[VAL_56:.*]] = arith.andi %[[VAL_54]], %[[VAL_55]] : i1 @@ -522,10 +522,10 @@ // CHECK: } // CHECK: %[[VAL_64:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_53]] : index // CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_48]], %[[VAL_4]] : index -// CHECK: %[[VAL_66:.*]] = select %[[VAL_64]], %[[VAL_65]], %[[VAL_48]] : index +// CHECK: %[[VAL_66:.*]] = arith.select %[[VAL_64]], %[[VAL_65]], %[[VAL_48]] : index // CHECK: %[[VAL_67:.*]] = arith.cmpi eq, %[[VAL_51]], %[[VAL_53]] : index // CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_49]], %[[VAL_4]] : index -// CHECK: %[[VAL_69:.*]] = select %[[VAL_67]], %[[VAL_68]], %[[VAL_49]] : index +// CHECK: %[[VAL_69:.*]] = arith.select %[[VAL_67]], %[[VAL_68]], %[[VAL_49]] : index // CHECK: scf.yield %[[VAL_66]], %[[VAL_69]] : index, index // CHECK: } // CHECK: scf.for %[[VAL_70:.*]] = %[[VAL_71:.*]]#0 to %[[VAL_38]] step %[[VAL_4]] { @@ -566,10 +566,10 @@ // CHECK: } // CHECK: %[[VAL_92:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_32]] : index // CHECK: %[[VAL_93:.*]] = arith.addi %[[VAL_27]], %[[VAL_4]] : index -// CHECK: %[[VAL_94:.*]] = select %[[VAL_92]], %[[VAL_93]], %[[VAL_27]] : index +// CHECK: %[[VAL_94:.*]] = arith.select %[[VAL_92]], %[[VAL_93]], %[[VAL_27]] : index // CHECK: %[[VAL_95:.*]] = arith.cmpi eq, %[[VAL_30]], %[[VAL_32]] : index // CHECK: %[[VAL_96:.*]] = arith.addi %[[VAL_28]], %[[VAL_4]] : index -// CHECK: %[[VAL_97:.*]] = select %[[VAL_95]], %[[VAL_96]], %[[VAL_28]] : index +// CHECK: %[[VAL_97:.*]] = arith.select %[[VAL_95]], %[[VAL_96]], %[[VAL_28]] : index // CHECK: scf.yield %[[VAL_94]], %[[VAL_97]] : index, index // CHECK: } // CHECK: scf.for %[[VAL_98:.*]] = %[[VAL_99:.*]]#0 to %[[VAL_18]] step %[[VAL_4]] { @@ -641,7 +641,7 @@ // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_27]]] : memref // CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref // CHECK: %[[VAL_31:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_29]] : index -// CHECK: %[[VAL_32:.*]] = select %[[VAL_31]], %[[VAL_30]], %[[VAL_29]] : index +// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_31]], %[[VAL_30]], %[[VAL_29]] : index // CHECK: %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_32]] : index // CHECK: %[[VAL_34:.*]] = arith.cmpi eq, %[[VAL_30]], %[[VAL_32]] : index // CHECK: %[[VAL_35:.*]] = arith.andi %[[VAL_33]], %[[VAL_34]] : i1 @@ -662,7 +662,7 @@ // CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_48]]] : memref // CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_49]]] : memref // CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_50]] : index -// CHECK: %[[VAL_53:.*]] = select %[[VAL_52]], %[[VAL_51]], %[[VAL_50]] : index +// CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[VAL_51]], %[[VAL_50]] : index // CHECK: 
%[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_53]] : index // CHECK: %[[VAL_55:.*]] = arith.cmpi eq, %[[VAL_51]], %[[VAL_53]] : index // CHECK: %[[VAL_56:.*]] = arith.andi %[[VAL_54]], %[[VAL_55]] : i1 @@ -675,20 +675,20 @@ // CHECK: } // CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_53]] : index // CHECK: %[[VAL_61:.*]] = arith.addi %[[VAL_48]], %[[VAL_4]] : index -// CHECK: %[[VAL_62:.*]] = select %[[VAL_60]], %[[VAL_61]], %[[VAL_48]] : index +// CHECK: %[[VAL_62:.*]] = arith.select %[[VAL_60]], %[[VAL_61]], %[[VAL_48]] : index // CHECK: %[[VAL_63:.*]] = arith.cmpi eq, %[[VAL_51]], %[[VAL_53]] : index // CHECK: %[[VAL_64:.*]] = arith.addi %[[VAL_49]], %[[VAL_4]] : index -// CHECK: %[[VAL_65:.*]] = select %[[VAL_63]], %[[VAL_64]], %[[VAL_49]] : index +// CHECK: %[[VAL_65:.*]] = arith.select %[[VAL_63]], %[[VAL_64]], %[[VAL_49]] : index // CHECK: scf.yield %[[VAL_62]], %[[VAL_65]] : index, index // CHECK: } // CHECK: } else { // CHECK: } // CHECK: %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_32]] : index // CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_27]], %[[VAL_4]] : index -// CHECK: %[[VAL_68:.*]] = select %[[VAL_66]], %[[VAL_67]], %[[VAL_27]] : index +// CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_66]], %[[VAL_67]], %[[VAL_27]] : index // CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_30]], %[[VAL_32]] : index // CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_28]], %[[VAL_4]] : index -// CHECK: %[[VAL_71:.*]] = select %[[VAL_69]], %[[VAL_70]], %[[VAL_28]] : index +// CHECK: %[[VAL_71:.*]] = arith.select %[[VAL_69]], %[[VAL_70]], %[[VAL_28]] : index // CHECK: scf.yield %[[VAL_68]], %[[VAL_71]] : index, index // CHECK: } // CHECK: %[[VAL_72:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<32x16xf32> @@ -759,7 +759,7 @@ // CHECK: } // CHECK: %[[VAL_43:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_34]] : index // CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index -// CHECK: %[[VAL_45:.*]] = select %[[VAL_43]], %[[VAL_44]], %[[VAL_33]] : index +// CHECK: %[[VAL_45:.*]] = arith.select %[[VAL_43]], %[[VAL_44]], %[[VAL_33]] : index // CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_34]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_45]], %[[VAL_46]] : index, index // CHECK: } @@ -784,7 +784,7 @@ // CHECK: } // CHECK: %[[VAL_58:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_23]] : index // CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_22]], %[[VAL_7]] : index -// CHECK: %[[VAL_60:.*]] = select %[[VAL_58]], %[[VAL_59]], %[[VAL_22]] : index +// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_58]], %[[VAL_59]], %[[VAL_22]] : index // CHECK: %[[VAL_61:.*]] = arith.addi %[[VAL_23]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_60]], %[[VAL_61]] : index, index // CHECK: } @@ -1155,10 +1155,10 @@ // CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref // CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref // CHECK: %[[VAL_63:.*]] = arith.cmpi ult, %[[VAL_62]], %[[VAL_61]] : index -// CHECK: %[[VAL_64:.*]] = select %[[VAL_63]], %[[VAL_62]], %[[VAL_61]] : index +// CHECK: %[[VAL_64:.*]] = arith.select %[[VAL_63]], %[[VAL_62]], %[[VAL_61]] : index // CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref // CHECK: %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_64]] : index -// CHECK: %[[VAL_67:.*]] = select %[[VAL_66]], %[[VAL_65]], %[[VAL_64]] : index +// CHECK: %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_65]], %[[VAL_64]] : index // CHECK: %[[VAL_68:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index // 
CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_67]] : index // CHECK: %[[VAL_70:.*]] = arith.andi %[[VAL_68]], %[[VAL_69]] : i1 @@ -1201,13 +1201,13 @@ // CHECK: } // CHECK: %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index // CHECK: %[[VAL_99:.*]] = arith.addi %[[VAL_57]], %[[VAL_7]] : index -// CHECK: %[[VAL_100:.*]] = select %[[VAL_98]], %[[VAL_99]], %[[VAL_57]] : index +// CHECK: %[[VAL_100:.*]] = arith.select %[[VAL_98]], %[[VAL_99]], %[[VAL_57]] : index // CHECK: %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_67]] : index // CHECK: %[[VAL_102:.*]] = arith.addi %[[VAL_58]], %[[VAL_7]] : index -// CHECK: %[[VAL_103:.*]] = select %[[VAL_101]], %[[VAL_102]], %[[VAL_58]] : index +// CHECK: %[[VAL_103:.*]] = arith.select %[[VAL_101]], %[[VAL_102]], %[[VAL_58]] : index // CHECK: %[[VAL_104:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_67]] : index // CHECK: %[[VAL_105:.*]] = arith.addi %[[VAL_59]], %[[VAL_7]] : index -// CHECK: %[[VAL_106:.*]] = select %[[VAL_104]], %[[VAL_105]], %[[VAL_59]] : index +// CHECK: %[[VAL_106:.*]] = arith.select %[[VAL_104]], %[[VAL_105]], %[[VAL_59]] : index // CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_106]], %[[VAL_107:.*]] : index, index, index, f32 // CHECK: } // CHECK: %[[VAL_108:.*]]:3 = scf.while (%[[VAL_109:.*]] = %[[VAL_110:.*]]#0, %[[VAL_111:.*]] = %[[VAL_110]]#1, %[[VAL_112:.*]] = %[[VAL_110]]#3) : (index, index, f32) -> (index, index, f32) { @@ -1220,7 +1220,7 @@ // CHECK: %[[VAL_119:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_116]]] : memref // CHECK: %[[VAL_120:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_117]]] : memref // CHECK: %[[VAL_121:.*]] = arith.cmpi ult, %[[VAL_120]], %[[VAL_119]] : index -// CHECK: %[[VAL_122:.*]] = select %[[VAL_121]], %[[VAL_120]], %[[VAL_119]] : index +// CHECK: %[[VAL_122:.*]] = arith.select %[[VAL_121]], %[[VAL_120]], %[[VAL_119]] : index // CHECK: %[[VAL_123:.*]] = arith.cmpi eq, %[[VAL_119]], %[[VAL_122]] : index // CHECK: %[[VAL_124:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_122]] : index // CHECK: %[[VAL_125:.*]] = arith.andi %[[VAL_123]], %[[VAL_124]] : i1 @@ -1237,10 +1237,10 @@ // CHECK: } // CHECK: %[[VAL_133:.*]] = arith.cmpi eq, %[[VAL_119]], %[[VAL_122]] : index // CHECK: %[[VAL_134:.*]] = arith.addi %[[VAL_116]], %[[VAL_7]] : index -// CHECK: %[[VAL_135:.*]] = select %[[VAL_133]], %[[VAL_134]], %[[VAL_116]] : index +// CHECK: %[[VAL_135:.*]] = arith.select %[[VAL_133]], %[[VAL_134]], %[[VAL_116]] : index // CHECK: %[[VAL_136:.*]] = arith.cmpi eq, %[[VAL_120]], %[[VAL_122]] : index // CHECK: %[[VAL_137:.*]] = arith.addi %[[VAL_117]], %[[VAL_7]] : index -// CHECK: %[[VAL_138:.*]] = select %[[VAL_136]], %[[VAL_137]], %[[VAL_117]] : index +// CHECK: %[[VAL_138:.*]] = arith.select %[[VAL_136]], %[[VAL_137]], %[[VAL_117]] : index // CHECK: scf.yield %[[VAL_135]], %[[VAL_138]], %[[VAL_139:.*]] : index, index, f32 // CHECK: } // CHECK: %[[VAL_140:.*]] = scf.for %[[VAL_141:.*]] = %[[VAL_142:.*]]#2 to %[[VAL_46]] step %[[VAL_7]] iter_args(%[[VAL_143:.*]] = %[[VAL_144:.*]]#2) -> (f32) { @@ -1266,7 +1266,7 @@ // CHECK: } // CHECK: %[[VAL_158:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_33]] : index // CHECK: %[[VAL_159:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index -// CHECK: %[[VAL_160:.*]] = select %[[VAL_158]], %[[VAL_159]], %[[VAL_32]] : index +// CHECK: %[[VAL_160:.*]] = arith.select %[[VAL_158]], %[[VAL_159]], %[[VAL_32]] : index // CHECK: %[[VAL_161:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index // CHECK: scf.yield %[[VAL_160]], %[[VAL_161]] : index, index // CHECK: } 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir @@ -152,7 +152,7 @@ // CHECK: } // CHECK: %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_28]] : index // CHECK: %[[VAL_36:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index -// CHECK: %[[VAL_37:.*]] = select %[[VAL_35]], %[[VAL_36]], %[[VAL_27]] : index +// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_27]] : index // CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_28]], %[[VAL_9]] : index // CHECK: scf.yield %[[VAL_37]], %[[VAL_38]] : index, index // CHECK: } @@ -270,7 +270,7 @@ // CHECK: } // CHECK: %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index // CHECK: %[[VAL_36:.*]] = arith.addi %[[VAL_23]], %[[VAL_8]] : index -// CHECK: %[[VAL_37:.*]] = select %[[VAL_35]], %[[VAL_36]], %[[VAL_23]] : index +// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_23]] : index // CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_24]], %[[VAL_8]] : index // CHECK: scf.yield %[[VAL_37]], %[[VAL_38]] : index, index // CHECK: } @@ -396,7 +396,7 @@ // CHECK: } // CHECK: %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index // CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index -// CHECK: %[[VAL_47:.*]] = select %[[VAL_45]], %[[VAL_46]], %[[VAL_37]] : index +// CHECK: %[[VAL_47:.*]] = arith.select %[[VAL_45]], %[[VAL_46]], %[[VAL_37]] : index // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index // CHECK: scf.yield %[[VAL_47]], %[[VAL_48]] : index, index // CHECK: } @@ -415,7 +415,7 @@ // CHECK: } // CHECK: %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index // CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index -// CHECK: %[[VAL_56:.*]] = select %[[VAL_54]], %[[VAL_55]], %[[VAL_26]] : index +// CHECK: %[[VAL_56:.*]] = arith.select %[[VAL_54]], %[[VAL_55]], %[[VAL_26]] : index // CHECK: %[[VAL_57:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index // CHECK: scf.yield %[[VAL_56]], %[[VAL_57]] : index, index // CHECK: } @@ -541,7 +541,7 @@ // CHECK: } // CHECK: %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index // CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_21]], %[[VAL_8]] : index -// CHECK: %[[VAL_39:.*]] = select %[[VAL_37]], %[[VAL_38]], %[[VAL_21]] : index +// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_37]], %[[VAL_38]], %[[VAL_21]] : index // CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_22]], %[[VAL_8]] : index // CHECK: scf.yield %[[VAL_39]], %[[VAL_40]] : index, index // CHECK: } @@ -670,7 +670,7 @@ // CHECK: } // CHECK: %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_40]], %[[VAL_39]] : index // CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index -// CHECK: %[[VAL_48:.*]] = select %[[VAL_46]], %[[VAL_47]], %[[VAL_38]] : index +// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_46]], %[[VAL_47]], %[[VAL_38]] : index // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_39]], %[[VAL_9]] : index // CHECK: scf.yield %[[VAL_48]], %[[VAL_49]] : index, index // CHECK: } @@ -692,7 +692,7 @@ // CHECK: } // CHECK: %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_25]] : index // CHECK: %[[VAL_57:.*]] = arith.addi %[[VAL_24]], %[[VAL_9]] : index -// CHECK: %[[VAL_58:.*]] = select %[[VAL_56]], %[[VAL_57]], %[[VAL_24]] : index +// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_57]], %[[VAL_24]] : index // CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_25]], %[[VAL_9]] : index // 
CHECK: scf.yield %[[VAL_58]], %[[VAL_59]] : index, index // CHECK: } @@ -827,7 +827,7 @@ // CHECK: } // CHECK: %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_35]] : index // CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_34]], %[[VAL_8]] : index -// CHECK: %[[VAL_48:.*]] = select %[[VAL_46]], %[[VAL_47]], %[[VAL_34]] : index +// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_46]], %[[VAL_47]], %[[VAL_34]] : index // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_35]], %[[VAL_8]] : index // CHECK: scf.yield %[[VAL_48]], %[[VAL_49]] : index, index // CHECK: } @@ -850,7 +850,7 @@ // CHECK: } // CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index // CHECK: %[[VAL_58:.*]] = arith.addi %[[VAL_23]], %[[VAL_8]] : index -// CHECK: %[[VAL_59:.*]] = select %[[VAL_57]], %[[VAL_58]], %[[VAL_23]] : index +// CHECK: %[[VAL_59:.*]] = arith.select %[[VAL_57]], %[[VAL_58]], %[[VAL_23]] : index // CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_24]], %[[VAL_8]] : index // CHECK: scf.yield %[[VAL_59]], %[[VAL_60]] : index, index // CHECK: } @@ -992,7 +992,7 @@ // CHECK: } // CHECK: %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_49]] : index // CHECK: %[[VAL_57:.*]] = arith.addi %[[VAL_48]], %[[VAL_9]] : index -// CHECK: %[[VAL_58:.*]] = select %[[VAL_56]], %[[VAL_57]], %[[VAL_48]] : index +// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_57]], %[[VAL_48]] : index // CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_49]], %[[VAL_9]] : index // CHECK: scf.yield %[[VAL_58]], %[[VAL_59]] : index, index // CHECK: } @@ -1011,7 +1011,7 @@ // CHECK: } // CHECK: %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index // CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index -// CHECK: %[[VAL_67:.*]] = select %[[VAL_65]], %[[VAL_66]], %[[VAL_37]] : index +// CHECK: %[[VAL_67:.*]] = arith.select %[[VAL_65]], %[[VAL_66]], %[[VAL_37]] : index // CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index // CHECK: scf.yield %[[VAL_67]], %[[VAL_68]] : index, index // CHECK: } @@ -1034,7 +1034,7 @@ // CHECK: } // CHECK: %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index // CHECK: %[[VAL_77:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index -// CHECK: %[[VAL_78:.*]] = select %[[VAL_76]], %[[VAL_77]], %[[VAL_26]] : index +// CHECK: %[[VAL_78:.*]] = arith.select %[[VAL_76]], %[[VAL_77]], %[[VAL_26]] : index // CHECK: %[[VAL_79:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index // CHECK: scf.yield %[[VAL_78]], %[[VAL_79]] : index, index // CHECK: } diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir @@ -194,7 +194,7 @@ // CHECK: } // CHECK: %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index // CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_18]], %[[VAL_6]] : index -// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index +// CHECK: %[[VAL_28:.*]] = arith.select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index // CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index // CHECK: } @@ -255,7 +255,7 @@ // CHECK: } // CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index // CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_6]] : index -// CHECK: %[[VAL_29:.*]] = select %[[VAL_27]], %[[VAL_28]], %[[VAL_18]] : index +// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], 
%[[VAL_18]] : index // CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_29]], %[[VAL_30]] : index, index // CHECK: } diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir @@ -58,7 +58,7 @@ // CHECK: } // CHECK: %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index // CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_18]], %[[VAL_6]] : index -// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index +// CHECK: %[[VAL_28:.*]] = arith.select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index // CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index // CHECK: } @@ -120,7 +120,7 @@ // CHECK: } // CHECK: %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index // CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index -// CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index +// CHECK: %[[VAL_30:.*]] = arith.select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index // CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_20]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_30]], %[[VAL_31]] : index, index // CHECK: } @@ -321,7 +321,7 @@ // CHECK: } // CHECK: %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index // CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_18]], %[[VAL_6]] : index -// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index +// CHECK: %[[VAL_28:.*]] = arith.select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index // CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index // CHECK: } @@ -381,7 +381,7 @@ // CHECK: } // CHECK: %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index // CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_18]], %[[VAL_6]] : index -// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index +// CHECK: %[[VAL_28:.*]] = arith.select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index // CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index // CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index // CHECK: } diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir @@ -97,7 +97,7 @@ // CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_40]]] : memref // CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_41]]] : memref // CHECK: %[[VAL_45:.*]] = arith.cmpi ult, %[[VAL_44]], %[[VAL_43]] : index -// CHECK: %[[VAL_46:.*]] = select %[[VAL_45]], %[[VAL_44]], %[[VAL_43]] : index +// CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[VAL_44]], %[[VAL_43]] : index // CHECK: %[[VAL_47:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_46]] : index // CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_46]] : index // CHECK: %[[VAL_49:.*]] = arith.andi %[[VAL_47]], %[[VAL_48]] : i1 @@ -131,10 +131,10 @@ // CHECK: } // CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_46]] : index // CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_40]], %[[VAL_4]] : index -// CHECK: %[[VAL_71:.*]] = select %[[VAL_69]], %[[VAL_70]], %[[VAL_40]] : index +// CHECK: %[[VAL_71:.*]] = arith.select %[[VAL_69]], %[[VAL_70]], %[[VAL_40]] : index 
// CHECK: %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_46]] : index
// CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_41]], %[[VAL_4]] : index
-// CHECK: %[[VAL_74:.*]] = select %[[VAL_72]], %[[VAL_73]], %[[VAL_41]] : index
+// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_72]], %[[VAL_73]], %[[VAL_41]] : index
// CHECK: scf.yield %[[VAL_71]], %[[VAL_74]], %[[VAL_75:.*]] : index, index, index
// CHECK: }
// CHECK: sparse_tensor.compress %[[VAL_8]], %[[VAL_19]], %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_76:.*]]#2 : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref, memref, memref, memref, index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -192,7 +192,7 @@
// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_34]]] : memref
// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_35]]] : memref
// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_37]], %[[VAL_36]] : index
-// CHECK: %[[VAL_39:.*]] = select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index
+// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index
// CHECK: memref.store %[[VAL_39]], %[[VAL_23]]{{\[}}%[[VAL_2]]] : memref
// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
// CHECK: %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
@@ -214,7 +214,7 @@
// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref
// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_56]]] : memref
// CHECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_57]] : index
-// CHECK: %[[VAL_60:.*]] = select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
+// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
// CHECK: memref.store %[[VAL_60]], %[[VAL_23]]{{\[}}%[[VAL_3]]] : memref
// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
// CHECK: %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
@@ -236,7 +236,7 @@
// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_77]]] : memref
// CHECK: %[[VAL_81:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_78]]] : memref
// CHECK: %[[VAL_82:.*]] = arith.cmpi ult, %[[VAL_81]], %[[VAL_80]] : index
-// CHECK: %[[VAL_83:.*]] = select %[[VAL_82]], %[[VAL_81]], %[[VAL_80]] : index
+// CHECK: %[[VAL_83:.*]] = arith.select %[[VAL_82]], %[[VAL_81]], %[[VAL_80]] : index
// CHECK: memref.store %[[VAL_83]], %[[VAL_23]]{{\[}}%[[VAL_4]]] : memref
// CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
// CHECK: %[[VAL_85:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
@@ -252,10 +252,10 @@
// CHECK: }
// CHECK: %[[VAL_92:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
// CHECK: %[[VAL_93:.*]] = arith.addi %[[VAL_77]], %[[VAL_3]] : index
-// CHECK: %[[VAL_94:.*]] = select %[[VAL_92]], %[[VAL_93]], %[[VAL_77]] : index
+// CHECK: %[[VAL_94:.*]] = arith.select %[[VAL_92]], %[[VAL_93]], %[[VAL_77]] : index
// CHECK: %[[VAL_95:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
// CHECK: %[[VAL_96:.*]] = arith.addi %[[VAL_78]], %[[VAL_3]] : index
-// CHECK: %[[VAL_97:.*]] = select %[[VAL_95]], %[[VAL_96]], %[[VAL_78]] : index
+// CHECK: %[[VAL_97:.*]] = arith.select %[[VAL_95]], %[[VAL_96]], %[[VAL_78]] : index
// CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32
// CHECK: }
// CHECK: sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[VAL_99:.*]]#2 : tensor, memref, i32
@@ -263,20 +263,20 @@
// CHECK: }
// CHECK: %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
-// CHECK: %[[VAL_102:.*]] = select %[[VAL_100]], %[[VAL_101]], %[[VAL_55]] : index
+// CHECK: %[[VAL_102:.*]] = arith.select %[[VAL_100]], %[[VAL_101]], %[[VAL_55]] : index
// CHECK: %[[VAL_103:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
// CHECK: %[[VAL_104:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
-// CHECK: %[[VAL_105:.*]] = select %[[VAL_103]], %[[VAL_104]], %[[VAL_56]] : index
+// CHECK: %[[VAL_105:.*]] = arith.select %[[VAL_103]], %[[VAL_104]], %[[VAL_56]] : index
// CHECK: scf.yield %[[VAL_102]], %[[VAL_105]] : index, index
// CHECK: }
// CHECK: } else {
// CHECK: }
// CHECK: %[[VAL_106:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
// CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
-// CHECK: %[[VAL_108:.*]] = select %[[VAL_106]], %[[VAL_107]], %[[VAL_34]] : index
+// CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_106]], %[[VAL_107]], %[[VAL_34]] : index
// CHECK: %[[VAL_109:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
// CHECK: %[[VAL_110:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
-// CHECK: %[[VAL_111:.*]] = select %[[VAL_109]], %[[VAL_110]], %[[VAL_35]] : index
+// CHECK: %[[VAL_111:.*]] = arith.select %[[VAL_109]], %[[VAL_110]], %[[VAL_35]] : index
// CHECK: scf.yield %[[VAL_108]], %[[VAL_111]] : index, index
// CHECK: }
// CHECK: %[[VAL_112:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor
@@ -354,7 +354,7 @@
// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_41]]] : memref
// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_42]]] : memref
// CHECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_44]] : index
-// CHECK: %[[VAL_47:.*]] = select %[[VAL_46]], %[[VAL_45]], %[[VAL_44]] : index
+// CHECK: %[[VAL_47:.*]] = arith.select %[[VAL_46]], %[[VAL_45]], %[[VAL_44]] : index
// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index
// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index
// CHECK: %[[VAL_50:.*]] = arith.andi %[[VAL_48]], %[[VAL_49]] : i1
@@ -388,10 +388,10 @@
// CHECK: }
// CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index
// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index
-// CHECK: %[[VAL_72:.*]] = select %[[VAL_70]], %[[VAL_71]], %[[VAL_41]] : index
+// CHECK: %[[VAL_72:.*]] = arith.select %[[VAL_70]], %[[VAL_71]], %[[VAL_41]] : index
// CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index
// CHECK: %[[VAL_74:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index
-// CHECK: %[[VAL_75:.*]] = select %[[VAL_73]], %[[VAL_74]], %[[VAL_42]] : index
+// CHECK: %[[VAL_75:.*]] = arith.select %[[VAL_73]], %[[VAL_74]], %[[VAL_42]] : index
// CHECK: scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]] : index, index, index
// CHECK: }
// CHECK: sparse_tensor.compress %[[VAL_9]], %[[VAL_20]], %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_77:.*]]#2 : tensor>, memref, memref, memref, memref, index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -51,7 +51,7 @@
// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_33]]] : memref
// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_34]]] : memref
// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_37]], %[[VAL_36]] : index
-// CHECK: %[[VAL_39:.*]] = select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index
+// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index
// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
// CHECK: %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
// CHECK: %[[VAL_42:.*]] = arith.andi %[[VAL_40]], %[[VAL_41]] : i1
@@ -82,10 +82,10 @@
// CHECK: }
// CHECK: %[[VAL_58:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_33]], %[[VAL_8]] : index
-// CHECK: %[[VAL_60:.*]] = select %[[VAL_58]], %[[VAL_59]], %[[VAL_33]] : index
+// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_58]], %[[VAL_59]], %[[VAL_33]] : index
// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
// CHECK: %[[VAL_62:.*]] = arith.addi %[[VAL_34]], %[[VAL_8]] : index
-// CHECK: %[[VAL_63:.*]] = select %[[VAL_61]], %[[VAL_62]], %[[VAL_34]] : index
+// CHECK: %[[VAL_63:.*]] = arith.select %[[VAL_61]], %[[VAL_62]], %[[VAL_34]] : index
// CHECK: scf.yield %[[VAL_60]], %[[VAL_63]], %[[VAL_64:.*]] : index, index, f64
// CHECK: }
// CHECK: %[[VAL_65:.*]] = vector.insertelement %[[VAL_66:.*]]#2, %[[VAL_3]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
@@ -94,7 +94,7 @@
// CHECK: %[[VAL_71:.*]] = vector.create_mask %[[VAL_70]] : vector<8xi1>
// CHECK: %[[VAL_72:.*]] = vector.maskedload %[[VAL_11]]{{\[}}%[[VAL_68]]], %[[VAL_71]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xf64> into vector<8xf64>
// CHECK: %[[VAL_73:.*]] = arith.addf %[[VAL_69]], %[[VAL_72]] : vector<8xf64>
-// CHECK: %[[VAL_74:.*]] = select %[[VAL_71]], %[[VAL_73]], %[[VAL_69]] : vector<8xi1>, vector<8xf64>
+// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_71]], %[[VAL_73]], %[[VAL_69]] : vector<8xi1>, vector<8xf64>
// CHECK: scf.yield %[[VAL_74]] : vector<8xf64>
// CHECK: }
// CHECK: %[[VAL_75:.*]] = scf.for %[[VAL_76:.*]] = %[[VAL_66]]#1 to %[[VAL_25]] step %[[VAL_4]] iter_args(%[[VAL_77:.*]] = %[[VAL_78:.*]]) -> (vector<8xf64>) {
@@ -102,7 +102,7 @@
// CHECK: %[[VAL_80:.*]] = vector.create_mask %[[VAL_79]] : vector<8xi1>
// CHECK: %[[VAL_81:.*]] = vector.maskedload %[[VAL_14]]{{\[}}%[[VAL_76]]], %[[VAL_80]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xf64> into vector<8xf64>
// CHECK: %[[VAL_82:.*]] = arith.addf %[[VAL_77]], %[[VAL_81]] : vector<8xf64>
-// CHECK: %[[VAL_83:.*]] = select %[[VAL_80]], %[[VAL_82]], %[[VAL_77]] : vector<8xi1>, vector<8xf64>
+// CHECK: %[[VAL_83:.*]] = arith.select %[[VAL_80]], %[[VAL_82]], %[[VAL_77]] : vector<8xi1>, vector<8xf64>
// CHECK: scf.yield %[[VAL_83]] : vector<8xf64>
// CHECK: }
// CHECK: %[[VAL_84:.*]] = vector.reduction "add", %[[VAL_85:.*]] : vector<8xf64> into f64
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
deleted file mode 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: mlir-opt %s -std-bufferize | FileCheck %s
-
-// CHECK-LABEL: func @select(
-// CHECK-SAME: %[[PRED:.*]]: i1,
-// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor,
-// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor) -> tensor {
-// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref
-// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref
-// CHECK: %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref
-// CHECK: return %[[RET]] : tensor
-func @select(%arg0: i1, %arg1: tensor, %arg2: tensor) -> tensor {
- %0 = select %arg0, %arg1, %arg2 : tensor
- return %0 : tensor
-}
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -73,8 +73,8 @@
func @cond_br_same_successor_insert_select(
%cond : i1, %a : i32, %b : i32, %c : tensor<2xi32>, %d : tensor<2xi32>
) -> (i32, tensor<2xi32>) {
- // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
- // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG2]], %[[ARG3]]
+ // CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[ARG0]], %[[ARG1]]
+ // CHECK: %[[RES2:.*]] = arith.select %[[COND]], %[[ARG2]], %[[ARG3]]
// CHECK: return %[[RES]], %[[RES2]]
cond_br %cond, ^bb1(%a, %c : i32, tensor<2xi32>), ^bb1(%b, %d : i32, tensor<2xi32>)
@@ -105,8 +105,8 @@
// CHECK-LABEL: func @cond_br_passthrough(
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
- // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]]
- // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]]
+ // CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[ARG0]], %[[ARG2]]
+ // CHECK: %[[RES2:.*]] = arith.select %[[COND]], %[[ARG1]], %[[ARG2]]
// CHECK: return %[[RES]], %[[RES2]]
cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: @select_same_val
// CHECK: return %arg1
func @select_same_val(%arg0: i1, %arg1: i64) -> i64 {
- %0 = select %arg0, %arg1, %arg1 : i64
+ %0 = arith.select %arg0, %arg1, %arg1 : i64
return %0 : i64
}
@@ -13,7 +13,7 @@
// CHECK: return %arg1
func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 {
%0 = arith.cmpi eq, %arg0, %arg1 : i64
- %1 = select %0, %arg0, %arg1 : i64
+ %1 = arith.select %0, %arg0, %arg1 : i64
return %1 : i64
}
@@ -23,7 +23,7 @@
// CHECK: return %arg0
func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
%0 = arith.cmpi ne, %arg0, %arg1 : i64
- %1 = select %0, %arg0, %arg1 : i64
+ %1 = arith.select %0, %arg0, %arg1 : i64
return %1 : i64
}
@@ -35,7 +35,7 @@
func @select_extui(%arg0: i1) -> i64 {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
- %res = select %arg0, %c1_i64, %c0_i64 : i64
+ %res = arith.select %arg0, %c1_i64, %c0_i64 : i64
return %res : i64
}
@@ -47,7 +47,7 @@
func @select_extui2(%arg0: i1) -> i64 {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
- %res = select %arg0, %c0_i64, %c1_i64 : i64
+ %res = arith.select %arg0, %c0_i64, %c1_i64 : i64
return %res : i64
}
@@ -58,7 +58,7 @@
func @select_extui_i1(%arg0: i1) -> i1 {
%c0_i1 = arith.constant false
%c1_i1 = arith.constant true
- %res = select %arg0, %c1_i1, %c0_i1 : i1
+ %res = arith.select %arg0, %c1_i1, %c0_i1 : i1
return %res : i1
}
@@ -93,7 +93,7 @@
func @selToNot(%arg0: i1) -> i1 {
%true = arith.constant true
%false = arith.constant false
- %res = select %arg0, %false, %true : i1
+ %res = arith.select %arg0, %false, %true : i1
return %res : i1
}
@@ -105,6 +105,6 @@
// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
// CHECK: return %[[res]]
func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
- %res = select %arg0, %arg1, %arg2 : i1
+ %res = arith.select %arg0, %arg1, %arg2 : i1
return %res : i1
}
diff --git a/mlir/test/Dialect/Standard/expand-tanh.mlir b/mlir/test/Dialect/Standard/expand-tanh.mlir
--- a/mlir/test/Dialect/Standard/expand-tanh.mlir
+++ b/mlir/test/Dialect/Standard/expand-tanh.mlir
@@ -19,5 +19,5 @@
// CHECK: %[[DIVISOR2:.+]] = arith.addf %[[EXP2]], %[[ONE]] : f32
// CHECK: %[[RES2:.+]] = arith.divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
-// CHECK: %[[RESULT:.+]] = select %[[COND]], %[[RES1]], %[[RES2]] : f32
+// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
// CHECK: return %[[RESULT]]
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -825,10 +825,10 @@
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
-// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
+// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
-// CHECK: %[[T5:.*]] = select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
+// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK: return %[[T6]] : vector<2x3xi1>
@@ -848,13 +848,13 @@
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[B]] : index
-// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
+// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
-// CHECK: %[[T5:.*]] = select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
+// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: %[[T7:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
-// CHECK: %[[T8:.*]] = select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
+// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: return %[[T9]] : vector<2x1x7xi1>
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -355,10 +355,10 @@
// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
-// CHECK: %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
-// CHECK: %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
-// CHECK: %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
-// CHECK: %[[SEL3:.*]] = select %[[CMP3]], %[[VT3]], %[[VT7]] : vector<2x2xi1>, vector<2x2xf32>
+// CHECK: %[[SEL0:.*]] = arith.select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
+// CHECK: %[[SEL1:.*]] = arith.select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
+// CHECK: %[[SEL2:.*]] = arith.select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
+// CHECK: %[[SEL3:.*]] = arith.select %[[CMP3]], %[[VT3]], %[[VT7]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: vector.transfer_write %[[SEL0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK: vector.transfer_write %[[SEL1]], %[[ARG0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
// CHECK: vector.transfer_write %[[SEL2]], %[[ARG0]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
@@ -372,7 +372,7 @@
// Vector transfer split pattern only support single user right now.
%2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
%3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
- %4 = select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
+ %4 = arith.select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
return
}
@@ -510,12 +510,12 @@
// CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
// CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
- %2 = select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
+ %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
// CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: select %arg4, %12, %{{.*}} : vector<4xf32>
// CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
- %3 = select %arg4, %arg3, %arg2 : vector<1x4xf32>
+ %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32>
return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -61,20 +61,20 @@
%tci1 = arith.constant dense<1> : tensor<42xi1>
%vci1 = arith.constant dense<1> : vector<42xi1>
- // CHECK: %{{.*}} = select %{{.*}}, %arg3, %arg3 : index
- %21 = select %true, %idx, %idx : index
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %arg3, %arg3 : index
+ %21 = arith.select %true, %idx, %idx : index
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<42xi1>, tensor<42xi32>
- %22 = select %tci1, %tci32, %tci32 : tensor<42 x i1>, tensor<42 x i32>
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<42xi1>, tensor<42xi32>
+ %22 = arith.select %tci1, %tci32, %tci32 : tensor<42 x i1>, tensor<42 x i32>
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : vector<42xi1>, vector<42xi32>
- %23 = select %vci1, %vci32, %vci32 : vector<42 x i1>, vector<42 x i32>
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<42xi1>, vector<42xi32>
+ %23 = arith.select %vci1, %vci32, %vci32 : vector<42 x i1>, vector<42 x i32>
- // CHECK: %{{.*}} = select %{{.*}}, %arg3, %arg3 : index
- %24 = "std.select"(%true, %idx, %idx) : (i1, index, index) -> index
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %arg3, %arg3 : index
+ %24 = "arith.select"(%true, %idx, %idx) : (i1, index, index) -> index
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<42xi32>
- %25 = std.select %true, %tci32, %tci32 : tensor<42 x i32>
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<42xi32>
+ %25 = arith.select %true, %tci32, %tci32 : tensor<42 x i32>
%64 = arith.constant dense<0.> : vector<4 x f32>
%tcf32 = arith.constant dense<0.> : tensor<42 x f32>
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -55,7 +55,7 @@
^bb0(%cond : i32, %t : i32, %f : i32):
// expected-error@+2 {{different type than prior uses}}
// expected-note@-2 {{prior use here}}
- %r = select %cond, %t, %f : i32
+ %r = arith.select %cond, %t, %f : i32
}
// -----
@@ -63,7 +63,7 @@
func @func_with_ops(i32, i32, i32) {
^bb0(%cond : i32, %t : i32, %f : i32):
// expected-error@+1 {{op operand #0 must be bool-like}}
- %r = "std.select"(%cond, %t, %f) : (i32, i32, i32) -> i32
+ %r = "arith.select"(%cond, %t, %f) : (i32, i32, i32) -> i32
}
// -----
@@ -75,7 +75,7 @@
// message. In final state the error should refer to mismatch in true_value and
// false_value.
// expected-error@+1 {{type}}
- %r = "std.select"(%cond, %t, %f) : (i1, i32, i64) -> i32
+ %r = "arith.select"(%cond, %t, %f) : (i1, i32, i64) -> i32
}
// -----
@@ -83,7 +83,7 @@
func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
// expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
- %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
+ %r = "arith.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
}
// -----
@@ -91,7 +91,7 @@
func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
// expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
- %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
+ %r = "arith.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
}
// -----
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize \
// RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize -buffer-deallocation -convert-linalg-to-llvm \
// RUN: -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-elementwise.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-elementwise.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize \
+// RUN: mlir-opt %s -convert-elementwise-to-linalg \
// RUN: -arith-bufferize -linalg-bufferize -tensor-bufferize \
// RUN: -func-bufferize -buffer-deallocation -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm --convert-memref-to-llvm -convert-std-to-llvm \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize \
// RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize -buffer-deallocation -convert-linalg-to-llvm \
// RUN: -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize \
// RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize -buffer-deallocation \
// RUN: -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize \
// RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize -buffer-deallocation \
// RUN: -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-subtensor-insert.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-subtensor-insert.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-subtensor-insert.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-subtensor-insert.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize \
// RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize -buffer-deallocation \
// RUN: -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-e2e.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-e2e.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-bufferize -std-bufferize -linalg-bufferize \
+// RUN: mlir-opt %s -arith-bufferize -linalg-bufferize \
// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -buffer-deallocation -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-matmul.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -1,5 +1,5 @@
// UNSUPPORTED: asan
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -arith-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -arith-bufferize \
// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -buffer-deallocation -convert-linalg-to-loops -convert-scf-to-std \
// RUN: -convert-linalg-to-llvm -lower-affine -convert-scf-to-std --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
@@ -7,7 +7,7 @@
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="tile-sizes=1,2,3" -linalg-bufferize \
-// RUN: -scf-bufferize -std-bufferize -arith-bufferize -tensor-bufferize \
+// RUN: -scf-bufferize -arith-bufferize -tensor-bufferize \
// RUN: -func-bufferize \
// RUN: -finalizing-bufferize -convert-linalg-to-loops -convert-scf-to-std -convert-scf-to-std \
// RUN: -convert-linalg-to-llvm -lower-affine -convert-scf-to-std --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -130,7 +130,7 @@
f'convert-scf-to-std,'
f'func-bufferize,'
f'arith-bufferize,'
- f'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),'
+ f'builtin.func(tensor-bufferize,finalizing-bufferize),'
f'convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},'
f'lower-affine,'
f'convert-memref-to-llvm,'
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -120,7 +120,7 @@
f'convert-scf-to-std,'
f'func-bufferize,'
f'arith-bufferize,'
- f'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),'
+ f'builtin.func(tensor-bufferize,finalizing-bufferize),'
f'convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},'
f'lower-affine,'
f'convert-memref-to-llvm,'
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
@@ -72,7 +72,7 @@
f'convert-scf-to-std,'
f'func-bufferize,'
f'arith-bufferize,'
- f'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),'
+ f'builtin.func(tensor-bufferize,finalizing-bufferize),'
f'convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},'
f'lower-affine,'
f'convert-memref-to-llvm,'
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -80,7 +80,7 @@
f'convert-scf-to-std,'
f'func-bufferize,'
f'arith-bufferize,'
- f'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),'
+ f'builtin.func(tensor-bufferize,finalizing-bufferize),'
f'convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},'
f'lower-affine,'
f'convert-memref-to-llvm,'
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -178,7 +178,7 @@
f'convert-scf-to-std,'
f'func-bufferize,'
f'arith-bufferize,'
- f'builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),'
+ f'builtin.func(tensor-bufferize,finalizing-bufferize),'
f'convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},'
f'lower-affine,'
f'convert-memref-to-llvm,'
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -145,7 +145,7 @@
f"convert-scf-to-std,"
f"func-bufferize,"
f"arith-bufferize,"
- f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
+ f"builtin.func(tensor-bufferize,finalizing-bufferize),"
f"convert-vector-to-llvm{{reassociate-fp-reductions=1 enable-index-optimizations=1}},"
f"lower-affine,"
f"convert-memref-to-llvm,"
diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -55,7 +55,7 @@
// CHECK-LABEL: func @mismatch_operands
// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func @mismatch_operands(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
- // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
+ // CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[ARG0]], %[[ARG1]]
// CHECK: return %[[RES]]
cond_br %cond, ^bb1, ^bb2
@@ -71,8 +71,8 @@
// CHECK-LABEL: func @mismatch_operands_matching_arguments(
// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) {
- // CHECK: %[[RES0:.*]] = select %[[COND]], %[[ARG1]], %[[ARG0]]
- // CHECK: %[[RES1:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
+ // CHECK: %[[RES0:.*]] = arith.select %[[COND]], %[[ARG1]], %[[ARG0]]
+ // CHECK: %[[RES1:.*]] = arith.select %[[COND]], %[[ARG0]], %[[ARG1]]
// CHECK: return %[[RES1]], %[[RES0]]
cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32)
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -602,14 +602,14 @@
%c0 = arith.constant 0 : index
%1 = arith.cmpi slt, %0, %c0 : index
%2 = arith.addi %0, %c42 : index
- %3 = select %1, %2, %0 : index
+ %3 = arith.select %1, %2, %0 : index
%c43 = arith.constant 43 : index
%c42_0 = arith.constant 42 : index
%4 = arith.remsi %c43, %c42_0 : index
%c0_1 = arith.constant 0 : index
%5 = arith.cmpi slt, %4, %c0_1 : index
%6 = arith.addi %4, %c42_0 : index
- %7 = select %5, %6, %4 : index
+ %7 = arith.select %5, %6, %4 : index
return %3, %7 : index, index
}
@@ -628,20 +628,20 @@
%c-1 = arith.constant -1 : index
%0 = arith.cmpi slt, %c-43, %c0 : index
%1 = arith.subi %c-1, %c-43 : index
- %2 = select %0, %1, %c-43 : index
+ %2 = arith.select %0, %1, %c-43 : index
%3 = arith.divsi %2, %c42 : index
%4 = arith.subi %c-1, %3 : index
- %5 = select %0, %4, %3 : index
+ %5 = arith.select %0, %4, %3 : index
%c43 = arith.constant 43 : index
%c42_0 = arith.constant 42 : index
%c0_1 = arith.constant 0 : index
%c-1_2 = arith.constant -1 : index
%6 = arith.cmpi slt, %c43, %c0_1 : index
%7 = arith.subi %c-1_2, %c43 : index
- %8 = select %6, %7, %c43 : index
+ %8 = arith.select %6, %7, %c43 : index
%9 = arith.divsi %8, %c42_0 : index
%10 = arith.subi %c-1_2, %9 : index
- %11 = select %6, %10, %9 : index
+ %11 = arith.select %6, %10, %9 : index
return %5, %11 : index, index
}
@@ -660,11 +660,11 @@
%0 = arith.cmpi sle, %c-43, %c0 : index
%1 = arith.subi %c0, %c-43 : index
%2 = arith.subi %c-43, %c1 : index
- %3 = select %0, %1, %2 : index
+ %3 = arith.select %0, %1, %2 : index
%4 = arith.divsi %3, %c42 : index
%5 = arith.subi %c0, %4 : index
%6 = arith.addi %4, %c1 : index
- %7 = select %0, %5, %6 : index
+ %7 = arith.select %0, %5, %6 : index
// CHECK-DAG: %c2 = arith.constant 2 : index
%c43 = arith.constant 43 : index
%c42_0 = arith.constant 42 : index
@@ -673,11 +673,11 @@
%8 = arith.cmpi sle, %c43, %c0_1 : index
%9 = arith.subi %c0_1, %c43 : index
%10 = arith.subi %c43, %c1_2 : index
- %11 = select %8, %9, %10 : index
+ %11 = arith.select %8, %9, %10 : index
%12 = arith.divsi %11, %c42_0 : index
%13 = arith.subi %c0_1, %12 : index
%14 = arith.addi %12, %c1_2 : index
- %15 = select %8, %13, %14 : index
+ %15 = arith.select %8, %13, %14 : index
// CHECK-NEXT: return %c-1, %c2
return %7, %15 : index, index
diff --git a/mlir/test/Transforms/parametric-tiling.mlir b/mlir/test/Transforms/parametric-tiling.mlir
--- a/mlir/test/Transforms/parametric-tiling.mlir
+++ b/mlir/test/Transforms/parametric-tiling.mlir
@@ -41,11 +41,11 @@
// Upper bound for the inner loop min(%i + %step, %c44).
// COMMON: %[[stepped:.*]] = arith.addi %[[i]], %[[step]]
// COMMON-NEXT: arith.cmpi slt, %c44, %[[stepped]]
- // COMMON-NEXT: %[[ub:.*]] = select {{.*}}, %c44, %[[stepped]]
+ // COMMON-NEXT: %[[ub:.*]] = arith.select {{.*}}, %c44, %[[stepped]]
//
// TILE_74: %[[stepped2:.*]] = arith.addi %[[j]], %[[step2]]
// TILE_74-NEXT: arith.cmpi slt, %c44, %[[stepped2]]
- // TILE_74-NEXT: %[[ub2:.*]] = select {{.*}}, %c44, %[[stepped2]]
+ // TILE_74-NEXT: %[[ub2:.*]] = arith.select {{.*}}, %c44, %[[stepped2]]
// Created inner scf.
// COMMON:scf.for %[[ii:.*]] = %[[i]] to %[[ub:.*]] step %c1
@@ -109,10 +109,10 @@
// Upper bound for the inner loop min(%i + %step, %c44).
// COMMON: %[[stepped:.*]] = arith.addi %[[i]], %[[step]]
// COMMON-NEXT: arith.cmpi slt, %c44, %[[stepped]]
- // COMMON-NEXT: %[[ub:.*]] = select {{.*}}, %c44, %[[stepped]]
+ // COMMON-NEXT: %[[ub:.*]] = arith.select {{.*}}, %c44, %[[stepped]]
// TILE_74: %[[stepped2:.*]] = arith.addi %[[j]], %[[step2]]
// TILE_74-NEXT: arith.cmpi slt, %[[i]], %[[stepped2]]
- // TILE_74-NEXT: %[[ub2:.*]] = select {{.*}}, %[[i]], %[[stepped2]]
+ // TILE_74-NEXT: %[[ub2:.*]] = arith.select {{.*}}, %[[i]], %[[stepped2]]
//
// Created inner scf.
// COMMON:scf.for %[[ii:.*]] = %[[i]] to %[[ub:.*]] step %c1
diff --git a/mlir/test/Transforms/sccp-callgraph.mlir b/mlir/test/Transforms/sccp-callgraph.mlir
--- a/mlir/test/Transforms/sccp-callgraph.mlir
+++ b/mlir/test/Transforms/sccp-callgraph.mlir
@@ -262,11 +262,11 @@
// CHECK-LABEL: func private @unreferenced_private_function
func private @unreferenced_private_function() -> i32 {
- // CHECK: %[[RES:.*]] = select
+ // CHECK: %[[RES:.*]] = arith.select
// CHECK: return %[[RES]] : i32
%true = arith.constant true
%cst0 = arith.constant 0 : i32
%cst1 = arith.constant 1 : i32
- %result = select %true, %cst0, %cst1 : i32
+ %result = arith.select %true, %cst0, %cst1 : i32
return %result : i32
}
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -9,7 +9,7 @@
%cond = arith.constant true
%cst_1 = arith.constant 1 : i32
- %select = select %cond, %cst_1, %arg0 : i32
+ %select = arith.select %cond, %cst_1, %arg0 : i32
return %select : i32
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -22,7 +22,6 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
@@ -62,7 +61,6 @@
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
- mlir::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -65,7 +65,7 @@
private:
// Return the target shape based on op type.
static Optional> getShape(Operation *op) {
- if (isa(op))
+ if (isa(op))
return SmallVector(2, 2);
if (isa(op))
return SmallVector(3, 2);
@@ -96,8 +96,8 @@
}
static LogicalResult filter(Operation *op) {
- return success(isa(op));
+ return success(isa(op));
}
};