diff --git a/flang/include/flang/Lower/CharacterRuntime.h b/flang/include/flang/Lower/CharacterRuntime.h --- a/flang/include/flang/Lower/CharacterRuntime.h +++ b/flang/include/flang/Lower/CharacterRuntime.h @@ -18,7 +18,7 @@ /// Generate call to a character comparison for two ssa-values of type /// `boxchar`. mlir::Value genBoxCharCompare(AbstractConverter &converter, mlir::Location loc, - mlir::CmpIPredicate cmp, mlir::Value lhs, + mlir::arith::CmpIPredicate cmp, mlir::Value lhs, mlir::Value rhs); /// Generate call to a character comparison op for two unboxed variables. There @@ -26,9 +26,9 @@ /// reference to its buffer (`ref>`) and its LEN type parameter (some /// integral type). mlir::Value genRawCharCompare(AbstractConverter &converter, mlir::Location loc, - mlir::CmpIPredicate cmp, mlir::Value lhsBuff, - mlir::Value lhsLen, mlir::Value rhsBuff, - mlir::Value rhsLen); + mlir::arith::CmpIPredicate cmp, + mlir::Value lhsBuff, mlir::Value lhsLen, + mlir::Value rhsBuff, mlir::Value rhsLen); } // namespace lower } // namespace Fortran diff --git a/flang/include/flang/Lower/Support/Utils.h b/flang/include/flang/Lower/Support/Utils.h --- a/flang/include/flang/Lower/Support/Utils.h +++ b/flang/include/flang/Lower/Support/Utils.h @@ -30,9 +30,9 @@ } namespace fir { -/// Return the integer value of a ConstantOp. -inline std::int64_t toInt(mlir::ConstantOp cop) { - return cop.getValue().cast().getValue().getSExtValue(); +/// Return the integer value of a arith::ConstantOp. +inline std::int64_t toInt(mlir::arith::ConstantOp cop) { + return cop.value().cast().getValue().getSExtValue(); } } // namespace fir diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.h b/flang/include/flang/Optimizer/Dialect/FIROps.h --- a/flang/include/flang/Optimizer/Dialect/FIROps.h +++ b/flang/include/flang/Optimizer/Dialect/FIROps.h @@ -10,6 +10,7 @@ #define FORTRAN_OPTIMIZER_DIALECT_FIROPS_H #include "flang/Optimizer/Dialect/FIRType.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -23,7 +24,7 @@ class RealAttr; void buildCmpCOp(mlir::OpBuilder &builder, mlir::OperationState &result, - mlir::CmpFPredicate predicate, mlir::Value lhs, + mlir::arith::CmpFPredicate predicate, mlir::Value lhs, mlir::Value rhs); unsigned getCaseArgumentOffset(llvm::ArrayRef cases, unsigned dest); diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -310,7 +310,7 @@ argument. The length of the !fir.char type is ignored. 
```mlir - fir.char_convert %1 for %2 to %3 : !fir.ref>, i32, + fir.char_convert %1 for %2 to %3 : !fir.ref>, i32, !fir.ref> ``` @@ -2544,7 +2544,7 @@ let printer = "printCmpcOp(p, *this);"; - let builders = [OpBuilder<(ins "mlir::CmpFPredicate":$predicate, + let builders = [OpBuilder<(ins "mlir::arith::CmpFPredicate":$predicate, "mlir::Value":$lhs, "mlir::Value":$rhs), [{ buildCmpCOp($_builder, $_state, predicate, lhs, rhs); }]>]; @@ -2554,12 +2554,12 @@ return "predicate"; } - CmpFPredicate getPredicate() { - return (CmpFPredicate)(*this)->getAttrOfType( + arith::CmpFPredicate getPredicate() { + return (arith::CmpFPredicate)(*this)->getAttrOfType( getPredicateAttrName()).getInt(); } - static CmpFPredicate getPredicateByName(llvm::StringRef name); + static arith::CmpFPredicate getPredicateByName(llvm::StringRef name); }]; } @@ -2676,9 +2676,9 @@ operations with a single FMA operation. ```mlir - %98 = mulf %96, %97 : f32 + %98 = arith.mulf %96, %97 : f32 %99 = fir.no_reassoc %98 : f32 - %a0 = addf %99, %95 : f32 + %a0 = arith.addf %99, %95 : f32 ``` }]; diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h --- a/flang/include/flang/Optimizer/Support/InitFIR.h +++ b/flang/include/flang/Optimizer/Support/InitFIR.h @@ -13,6 +13,7 @@ #ifndef FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H #define FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H +#include "flang/Optimizer/CodeGen/CodeGen.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/Passes.h" @@ -27,7 +28,8 @@ #define FLANG_NONCODEGEN_DIALECT_LIST \ mlir::AffineDialect, FIROpsDialect, mlir::acc::OpenACCDialect, \ mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \ - mlir::StandardOpsDialect, mlir::vector::VectorDialect + mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect, \ + mlir::vector::VectorDialect // The definitive list of dialects used by flang. #define FLANG_DIALECT_LIST \ diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h --- a/flang/include/flang/Optimizer/Support/Utils.h +++ b/flang/include/flang/Optimizer/Support/Utils.h @@ -17,9 +17,9 @@ #include "mlir/IR/BuiltinAttributes.h" namespace fir { -/// Return the integer value of a ConstantOp. -inline std::int64_t toInt(mlir::ConstantOp cop) { - return cop.getValue().cast().getValue().getSExtValue(); +/// Return the integer value of a arith::ConstantOp. 
+inline std::int64_t toInt(mlir::arith::ConstantOp cop) { + return cop.value().cast().getValue().getSExtValue(); } } // namespace fir diff --git a/flang/include/flang/Optimizer/Transforms/RewritePatterns.td b/flang/include/flang/Optimizer/Transforms/RewritePatterns.td --- a/flang/include/flang/Optimizer/Transforms/RewritePatterns.td +++ b/flang/include/flang/Optimizer/Transforms/RewritePatterns.td @@ -15,6 +15,7 @@ #define FORTRAN_FIR_REWRITE_PATTERNS include "mlir/IR/OpBase.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" include "mlir/Dialect/StandardOps/IR/Ops.td" include "flang/Optimizer/Dialect/FIROps.td" @@ -46,12 +47,12 @@ ,(SmallerWidthPred $arg, $irm)]>; def createConstantOp - : NativeCodeCall<"$_builder.create" + : NativeCodeCall<"$_builder.create" "($_loc, $_builder.getIndexType(), " "rewriter.getIndexAttr($1.dyn_cast().getInt()))">; def ForwardConstantConvertPattern - : Pat<(fir_ConvertOp:$res (ConstantOp:$cnt $attr)), + : Pat<(fir_ConvertOp:$res (Arith_ConstantOp:$cnt $attr)), (createConstantOp $res, $attr), [(IndexTypePred $res) ,(IntegerTypePred $cnt)]>; diff --git a/flang/lib/Lower/CharacterExpr.cpp b/flang/lib/Lower/CharacterExpr.cpp --- a/flang/lib/Lower/CharacterExpr.cpp +++ b/flang/lib/Lower/CharacterExpr.cpp @@ -268,7 +268,8 @@ // Pad if needed. if (!compileTimeSameLength) { auto one = builder.createIntegerConstant(loc, lhs.getLen().getType(), 1); - auto maxPadding = builder.create(loc, lhs.getLen(), one); + auto maxPadding = + builder.create(loc, lhs.getLen(), one); createPadding(lhs, copyCount, maxPadding); } } @@ -276,17 +277,17 @@ fir::CharBoxValue Fortran::lower::CharacterExprHelper::createConcatenate( const fir::CharBoxValue &lhs, const fir::CharBoxValue &rhs) { mlir::Value len = - builder.create(loc, lhs.getLen(), rhs.getLen()); + builder.create(loc, lhs.getLen(), rhs.getLen()); auto temp = createTemp(getCharacterType(rhs), len); createCopy(temp, lhs, lhs.getLen()); auto one = builder.createIntegerConstant(loc, len.getType(), 1); - auto upperBound = builder.create(loc, len, one); + auto upperBound = builder.create(loc, len, one); auto lhsLen = builder.createConvert(loc, builder.getIndexType(), lhs.getLen()); Fortran::lower::DoLoopHelper{builder, loc}.createLoop( lhs.getLen(), upperBound, one, [&](Fortran::lower::FirOpBuilder &bldr, mlir::Value index) { - auto rhsIndex = bldr.create(loc, index, lhsLen); + auto rhsIndex = bldr.create(loc, index, lhsLen); auto charVal = createLoadCharAt(rhs, rhsIndex); createStoreCharAt(temp, index, charVal); }); @@ -312,7 +313,8 @@ auto lowerBound = castBounds[0]; // FIR CoordinateOp is zero based but Fortran substring are one based. 
auto one = builder.createIntegerConstant(loc, lowerBound.getType(), 1); - auto offset = builder.create(loc, lowerBound, one).getResult(); + auto offset = + builder.create(loc, lowerBound, one).getResult(); auto idxType = builder.getIndexType(); if (offset.getType() != idxType) offset = builder.createConvert(loc, idxType, offset); @@ -323,17 +325,17 @@ mlir::Value substringLen{}; if (nbounds < 2) { substringLen = - builder.create(loc, str.getLen(), castBounds[0]); + builder.create(loc, str.getLen(), castBounds[0]); } else { substringLen = - builder.create(loc, castBounds[1], castBounds[0]); + builder.create(loc, castBounds[1], castBounds[0]); } - substringLen = builder.create(loc, substringLen, one); + substringLen = builder.create(loc, substringLen, one); // Set length to zero if bounds were reversed (Fortran 2018 9.4.1) auto zero = builder.createIntegerConstant(loc, substringLen.getType(), 0); - auto cdt = builder.create(loc, mlir::CmpIPredicate::slt, - substringLen, zero); + auto cdt = builder.create( + loc, mlir::arith::CmpIPredicate::slt, substringLen, zero); substringLen = builder.create(loc, cdt, zero, substringLen); return {substringRef, substringLen}; diff --git a/flang/lib/Lower/CharacterRuntime.cpp b/flang/lib/Lower/CharacterRuntime.cpp --- a/flang/lib/Lower/CharacterRuntime.cpp +++ b/flang/lib/Lower/CharacterRuntime.cpp @@ -85,11 +85,10 @@ // Lower character operations //===----------------------------------------------------------------------===// -mlir::Value -Fortran::lower::genRawCharCompare(Fortran::lower::AbstractConverter &converter, - mlir::Location loc, mlir::CmpIPredicate cmp, - mlir::Value lhsBuff, mlir::Value lhsLen, - mlir::Value rhsBuff, mlir::Value rhsLen) { +mlir::Value Fortran::lower::genRawCharCompare( + Fortran::lower::AbstractConverter &converter, mlir::Location loc, + mlir::arith::CmpIPredicate cmp, mlir::Value lhsBuff, mlir::Value lhsLen, + mlir::Value rhsBuff, mlir::Value rhsLen) { auto &builder = converter.getFirOpBuilder(); mlir::FuncOp beginFunc; switch (discoverKind(lhsBuff.getType())) { @@ -113,13 +112,12 @@ llvm::SmallVector args = {lptr, rptr, llen, rlen}; auto tri = builder.create(loc, beginFunc, args).getResult(0); auto zero = builder.createIntegerConstant(loc, tri.getType(), 0); - return builder.create(loc, cmp, tri, zero); + return builder.create(loc, cmp, tri, zero); } -mlir::Value -Fortran::lower::genBoxCharCompare(Fortran::lower::AbstractConverter &converter, - mlir::Location loc, mlir::CmpIPredicate cmp, - mlir::Value lhs, mlir::Value rhs) { +mlir::Value Fortran::lower::genBoxCharCompare( + Fortran::lower::AbstractConverter &converter, mlir::Location loc, + mlir::arith::CmpIPredicate cmp, mlir::Value lhs, mlir::Value rhs) { auto &builder = converter.getFirOpBuilder(); Fortran::lower::CharacterExprHelper helper{builder, loc}; auto lhsPair = helper.materializeCharacter(lhs); diff --git a/flang/lib/Lower/ComplexExpr.cpp b/flang/lib/Lower/ComplexExpr.cpp --- a/flang/lib/Lower/ComplexExpr.cpp +++ b/flang/lib/Lower/ComplexExpr.cpp @@ -46,13 +46,15 @@ auto imag1 = extract(cplx1); auto imag2 = extract(cplx2); - mlir::CmpFPredicate predicate = - eq ? mlir::CmpFPredicate::UEQ : mlir::CmpFPredicate::UNE; + mlir::arith::CmpFPredicate predicate = + eq ? 
mlir::arith::CmpFPredicate::UEQ : mlir::arith::CmpFPredicate::UNE; mlir::Value realCmp = - builder.create(loc, predicate, real1, real2); + builder.create(loc, predicate, real1, real2); mlir::Value imagCmp = - builder.create(loc, predicate, imag1, imag2); + builder.create(loc, predicate, imag1, imag2); - return eq ? builder.create(loc, realCmp, imagCmp).getResult() - : builder.create(loc, realCmp, imagCmp).getResult(); + return eq ? builder.create(loc, realCmp, imagCmp) + .getResult() + : builder.create(loc, realCmp, imagCmp) + .getResult(); } diff --git a/flang/lib/Lower/DoLoopHelper.cpp b/flang/lib/Lower/DoLoopHelper.cpp --- a/flang/lib/Lower/DoLoopHelper.cpp +++ b/flang/lib/Lower/DoLoopHelper.cpp @@ -39,6 +39,6 @@ auto indexType = builder.getIndexType(); auto zero = builder.createIntegerConstant(loc, indexType, 0); auto one = builder.createIntegerConstant(loc, count.getType(), 1); - auto up = builder.create(loc, count, one); + auto up = builder.create(loc, count, one); createLoop(zero, up, one, bodyGenerator); } diff --git a/flang/lib/Lower/FIRBuilder.cpp b/flang/lib/Lower/FIRBuilder.cpp --- a/flang/lib/Lower/FIRBuilder.cpp +++ b/flang/lib/Lower/FIRBuilder.cpp @@ -48,12 +48,13 @@ mlir::Value Fortran::lower::FirOpBuilder::createIntegerConstant( mlir::Location loc, mlir::Type ty, std::int64_t cst) { - return create(loc, ty, getIntegerAttr(ty, cst)); + return create(loc, ty, getIntegerAttr(ty, cst)); } mlir::Value Fortran::lower::FirOpBuilder::createRealConstant( mlir::Location loc, mlir::Type realType, const llvm::APFloat &val) { - return create(loc, realType, getFloatAttr(realType, val)); + return create(loc, realType, + getFloatAttr(realType, val)); } mlir::Value @@ -67,7 +68,7 @@ } else { // mlir::FloatType. attr = getZeroAttr(realType); } - return create(loc, realType, attr); + return create(loc, realType, attr); } mlir::Value Fortran::lower::FirOpBuilder::allocateLocal( diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp --- a/flang/lib/Lower/IO.cpp +++ b/flang/lib/Lower/IO.cpp @@ -319,8 +319,9 @@ auto complexPartAddr = [&](int index) { return builder.create( loc, complexPartType, originalItemAddr, - llvm::SmallVector{builder.create( - loc, builder.getI32IntegerAttr(index))}); + llvm::SmallVector{ + builder.create( + loc, builder.getI32IntegerAttr(index))}); }; if (complexPartType) itemAddr = complexPartAddr(0); // real part @@ -332,7 +333,7 @@ inputFuncArgs.push_back( builder.createConvert(loc, inputFunc.getType().getInput(2), len)); } else if (itemType.isa()) { - inputFuncArgs.push_back(builder.create( + inputFuncArgs.push_back(builder.create( loc, builder.getI32IntegerAttr( itemType.cast().getWidth() / 8))); } @@ -373,7 +374,7 @@ auto upperValue = genFIRLoopIndex(control.upper); auto stepValue = control.step.has_value() ? 
genFIRLoopIndex(*control.step) - : builder.create(loc, 1); + : builder.create(loc, 1); auto genItemList = [&](const D &ioImpliedDo, bool inIterWhileLoop) { if constexpr (std::is_same_v) genInputItemList(converter, cookie, itemList, insertPt, checkResult, ok, @@ -430,28 +431,28 @@ static mlir::Value getDefaultFilename(Fortran::lower::FirOpBuilder &builder, mlir::Location loc, mlir::Type toType) { - mlir::Value null = - builder.create(loc, builder.getI64IntegerAttr(0)); + mlir::Value null = builder.create( + loc, builder.getI64IntegerAttr(0)); return builder.createConvert(loc, toType, null); } static mlir::Value getDefaultLineNo(Fortran::lower::FirOpBuilder &builder, mlir::Location loc, mlir::Type toType) { - return builder.create(loc, - builder.getIntegerAttr(toType, 0)); + return builder.create( + loc, builder.getIntegerAttr(toType, 0)); } static mlir::Value getDefaultScratch(Fortran::lower::FirOpBuilder &builder, mlir::Location loc, mlir::Type toType) { - mlir::Value null = - builder.create(loc, builder.getI64IntegerAttr(0)); + mlir::Value null = builder.create( + loc, builder.getI64IntegerAttr(0)); return builder.createConvert(loc, toType, null); } static mlir::Value getDefaultScratchLen(Fortran::lower::FirOpBuilder &builder, mlir::Location loc, mlir::Type toType) { - return builder.create(loc, - builder.getIntegerAttr(toType, 0)); + return builder.create( + loc, builder.getIntegerAttr(toType, 0)); } /// Lower a string literal. Many arguments to the runtime are conveyed as @@ -470,7 +471,7 @@ auto len = builder.createConvert(loc, lenTy, dataLen.second); if (ty2) { auto kindVal = helper.getCharacterKind(str.getType()); - auto kind = builder.create( + auto kind = builder.create( loc, builder.getIntegerAttr(ty2, kindVal)); return {buff, len, kind}; } @@ -777,7 +778,7 @@ getIORuntimeFunc(loc, builder); mlir::Type boolType = enableHandlers.getType().getInput(1); auto boolValue = [&](bool specifierIsPresent) { - return builder.create( + return builder.create( loc, builder.getIntegerAttr(boolType, specifierIsPresent)); }; llvm::SmallVector ioArgs = { @@ -998,7 +999,7 @@ auto ex = converter.genExprValue(Fortran::semantics::GetExpr(*e), loc); return builder.createConvert(loc, ty, ex); } - return builder.create( + return builder.create( loc, builder.getIntegerAttr(ty, Fortran::runtime::io::DefaultUnit)); } @@ -1291,7 +1292,7 @@ ioArgs.push_back(std::get<1>(pair)); } // unit (always last) - ioArgs.push_back(builder.create( + ioArgs.push_back(builder.create( loc, builder.getIntegerAttr(ioFuncTy.getInput(ioArgs.size()), Fortran::runtime::io::DefaultUnit))); } diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -948,7 +948,7 @@ auto arg = args[0]; auto type = arg.getType(); if (fir::isa_real(type)) { - // Runtime call to fp abs. An alternative would be to use mlir AbsFOp + // Runtime call to fp abs. An alternative would be to use mlir math::AbsOp // but it does not support all fir floating point types. return genRuntimeCall("abs", resultType, args); } @@ -957,9 +957,9 @@ // So, implement abs here without branching. 
auto shift = builder.createIntegerConstant(loc, intType, intType.getWidth() - 1); - auto mask = builder.create(loc, arg, shift); - auto xored = builder.create(loc, arg, mask); - return builder.create(loc, xored, mask); + auto mask = builder.create(loc, arg, shift); + auto xored = builder.create(loc, arg, mask); + return builder.create(loc, xored, mask); } if (fir::isa_complex(type)) { // Use HYPOT to fulfill the no underflow/overflow requirement. @@ -1021,7 +1021,7 @@ auto imag = Fortran::lower::ComplexExprHelper{builder, loc}.extractComplexPart( cplx, /*isImagPart=*/true); - auto negImag = builder.create(loc, imag); + auto negImag = builder.create(loc, imag); return Fortran::lower::ComplexExprHelper{builder, loc}.insertComplexPart( cplx, negImag, /*isImagPart=*/true); } @@ -1032,16 +1032,16 @@ assert(args.size() == 2); if (resultType.isa()) { auto zero = builder.createIntegerConstant(loc, resultType, 0); - auto diff = builder.create(loc, args[0], args[1]); - auto cmp = - builder.create(loc, mlir::CmpIPredicate::sgt, diff, zero); + auto diff = builder.create(loc, args[0], args[1]); + auto cmp = builder.create( + loc, mlir::arith::CmpIPredicate::sgt, diff, zero); return builder.create(loc, cmp, diff, zero); } assert(fir::isa_real(resultType) && "Only expects real and integer in DIM"); auto zero = builder.createRealZeroConstant(loc, resultType); - auto diff = builder.create(loc, args[0], args[1]); - auto cmp = - builder.create(loc, mlir::CmpFPredicate::OGT, diff, zero); + auto diff = builder.create(loc, args[0], args[1]); + auto cmp = builder.create( + loc, mlir::arith::CmpFPredicate::OGT, diff, zero); return builder.create(loc, cmp, diff, zero); } @@ -1053,7 +1053,7 @@ "Result must be double precision in DPROD"); auto a = builder.createConvert(loc, resultType, args[0]); auto b = builder.createConvert(loc, resultType, args[1]); - return builder.create(loc, a, b); + return builder.create(loc, a, b); } // FLOOR @@ -1072,7 +1072,7 @@ llvm::ArrayRef args) { assert(args.size() == 2); - return builder.create(loc, args[0], args[1]); + return builder.create(loc, args[0], args[1]); } // ICHAR @@ -1096,14 +1096,14 @@ mlir::Value IntrinsicLibrary::genIEOr(mlir::Type resultType, llvm::ArrayRef args) { assert(args.size() == 2); - return builder.create(loc, args[0], args[1]); + return builder.create(loc, args[0], args[1]); } // IOR mlir::Value IntrinsicLibrary::genIOr(mlir::Type resultType, llvm::ArrayRef args) { assert(args.size() == 2); - return builder.create(loc, args[0], args[1]); + return builder.create(loc, args[0], args[1]); } // LEN @@ -1154,12 +1154,12 @@ llvm::ArrayRef args) { assert(args.size() == 2); if (resultType.isa()) - return builder.create(loc, args[0], args[1]); + return builder.create(loc, args[0], args[1]); - // Use runtime. Note that mlir::RemFOp implements floating point + // Use runtime. Note that mlir::arith::RemFOp implements floating point // remainder, but it does not work with fir::Real type. - // TODO: consider using mlir::RemFOp when possible, that may help folding - // and optimizations. + // TODO: consider using mlir::arith::RemFOp when possible, that may help + // folding and optimizations. 
return genRuntimeCall("mod", resultType, args); } @@ -1179,17 +1179,18 @@ auto abs = genAbs(resultType, {args[0]}); if (resultType.isa()) { auto zero = builder.createIntegerConstant(loc, resultType, 0); - auto neg = builder.create(loc, zero, abs); - auto cmp = builder.create(loc, mlir::CmpIPredicate::slt, - args[1], zero); + auto neg = builder.create(loc, zero, abs); + auto cmp = builder.create( + loc, mlir::arith::CmpIPredicate::slt, args[1], zero); return builder.create(loc, cmp, neg, abs); } // TODO: Requirements when second argument is +0./0. auto zeroAttr = builder.getZeroAttr(resultType); - auto zero = builder.create(loc, resultType, zeroAttr); - auto neg = builder.create(loc, abs); - auto cmp = builder.create(loc, mlir::CmpFPredicate::OLT, - args[1], zero); + auto zero = + builder.create(loc, resultType, zeroAttr); + auto neg = builder.create(loc, abs); + auto cmp = builder.create( + loc, mlir::arith::CmpFPredicate::OLT, args[1], zero); return builder.create(loc, cmp, neg, abs); } @@ -1198,12 +1199,12 @@ static mlir::Value createExtremumCompare(mlir::Location loc, Fortran::lower::FirOpBuilder &builder, mlir::Value left, mlir::Value right) { - static constexpr auto integerPredicate = extremum == Extremum::Max - ? mlir::CmpIPredicate::sgt - : mlir::CmpIPredicate::slt; + static constexpr auto integerPredicate = + extremum == Extremum::Max ? mlir::arith::CmpIPredicate::sgt + : mlir::arith::CmpIPredicate::slt; static constexpr auto orderedCmp = extremum == Extremum::Max - ? mlir::CmpFPredicate::OGT - : mlir::CmpFPredicate::OLT; + ? mlir::arith::CmpFPredicate::OGT + : mlir::arith::CmpFPredicate::OLT; auto type = left.getType(); mlir::Value result; if (fir::isa_real(type)) { @@ -1213,33 +1214,37 @@ // Return the number if one of the inputs is NaN and the other is // a number. auto leftIsResult = - builder.create(loc, orderedCmp, left, right); - auto rightIsNan = builder.create( - loc, mlir::CmpFPredicate::UNE, right, right); - result = builder.create(loc, leftIsResult, rightIsNan); + builder.create(loc, orderedCmp, left, right); + auto rightIsNan = builder.create( + loc, mlir::arith::CmpFPredicate::UNE, right, right); + result = + builder.create(loc, leftIsResult, rightIsNan); } else if constexpr (behavior == ExtremumBehavior::IeeeMinMaximum) { // Always return NaNs if one the input is NaNs auto leftIsResult = - builder.create(loc, orderedCmp, left, right); - auto leftIsNan = builder.create( - loc, mlir::CmpFPredicate::UNE, left, left); - result = builder.create(loc, leftIsResult, leftIsNan); + builder.create(loc, orderedCmp, left, right); + auto leftIsNan = builder.create( + loc, mlir::arith::CmpFPredicate::UNE, left, left); + result = builder.create(loc, leftIsResult, leftIsNan); } else if constexpr (behavior == ExtremumBehavior::MinMaxss) { // If the left is a NaN, return the right whatever it is. - result = builder.create(loc, orderedCmp, left, right); + result = + builder.create(loc, orderedCmp, left, right); } else if constexpr (behavior == ExtremumBehavior::PgfortranLlvm) { // If one of the operand is a NaN, return left whatever it is. - static constexpr auto unorderedCmp = extremum == Extremum::Max - ? mlir::CmpFPredicate::UGT - : mlir::CmpFPredicate::ULT; - result = builder.create(loc, unorderedCmp, left, right); + static constexpr auto unorderedCmp = + extremum == Extremum::Max ? 
mlir::arith::CmpFPredicate::UGT + : mlir::arith::CmpFPredicate::ULT; + result = + builder.create(loc, unorderedCmp, left, right); } else { // TODO: ieeeMinNum/ieeeMaxNum static_assert(behavior == ExtremumBehavior::IeeeMinMaxNum, "ieeeMinNum/ieeeMaxNum behavior not implemented"); } } else if (fir::isa_integer(type)) { - result = builder.create(loc, integerPredicate, left, right); + result = + builder.create(loc, integerPredicate, left, right); } else if (type.isa()) { // TODO: ! character min and max is tricky because the result // length is the length of the longest argument! diff --git a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp @@ -62,11 +62,14 @@ /// ``` /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> -/// %3 = fir.embox %0 (%1) [%2] : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box> +/// %3 = fir.embox %0 (%1) [%2] : (!fir.ref>, +/// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box> /// ``` /// can be rewritten as /// ``` -/// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : (!fir.ref>, index, index, index, index, index) -> !fir.box> +/// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : +/// (!fir.ref>, index, index, index, index, index) -> +/// !fir.box> /// ``` class EmboxConversion : public mlir::OpRewritePattern { public: @@ -94,7 +97,7 @@ auto idxTy = rewriter.getIndexType(); for (auto ext : seqTy.getShape()) { auto iAttr = rewriter.getIndexAttr(ext); - auto extVal = rewriter.create(loc, idxTy, iAttr); + auto extVal = rewriter.create(loc, idxTy, iAttr); shapeOpers.push_back(extVal); } auto xbox = rewriter.create( @@ -139,11 +142,13 @@ /// /// For example, /// ``` -/// %5 = fir.rebox %3(%1) : (!fir.box>, !fir.shapeshift<1>) -> !fir.box> +/// %5 = fir.rebox %3(%1) : (!fir.box>, !fir.shapeshift<1>) -> +/// !fir.box> /// ``` /// converted to /// ``` -/// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box>, index, index) -> !fir.box> +/// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box>, +/// index, index) -> !fir.box> /// ``` class ReboxConversion : public mlir::OpRewritePattern { public: @@ -187,11 +192,14 @@ /// /// For example, /// ``` -/// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref +/// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref>, +/// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref /// ``` /// converted to /// ``` -/// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : (!fir.ref>, index, index, index, index, index, index) -> !fir.ref +/// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : +/// (!fir.ref>, index, index, index, index, index, index) -> +/// !fir.ref /// ``` class ArrayCoorConversion : public mlir::OpRewritePattern { public: @@ -237,8 +245,8 @@ auto &context = getContext(); mlir::OpBuilder rewriter(&context); mlir::ConversionTarget target(context); - target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](EmboxOp embox) { diff --git a/flang/lib/Optimizer/Dialect/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CMakeLists.txt --- a/flang/lib/Optimizer/Dialect/CMakeLists.txt +++ b/flang/lib/Optimizer/Dialect/CMakeLists.txt @@ -10,6 +10,7 @@ LINK_LIBS FIRSupport + MLIRArithmetic MLIROpenMPToLLVM MLIRLLVMToLLVMIRTranslation 
MLIRTargetLLVMIRExport diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -638,12 +638,13 @@ template static void printCmpOp(OpAsmPrinter &p, OPTY op) { p << ' '; - auto predSym = mlir::symbolizeCmpFPredicate( + auto predSym = mlir::arith::symbolizeCmpFPredicate( op->template getAttrOfType( OPTY::getPredicateAttrName()) .getInt()); assert(predSym.hasValue() && "invalid symbol value for predicate"); - p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; + p << '"' << mlir::arith::stringifyCmpFPredicate(predSym.getValue()) << '"' + << ", "; p.printOperand(op.lhs()); p << ", "; p.printOperand(op.rhs()); @@ -706,7 +707,7 @@ //===----------------------------------------------------------------------===// void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, - CmpFPredicate predicate, Value lhs, Value rhs) { + arith::CmpFPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(builder.getI1Type()); result.addAttribute( @@ -714,8 +715,9 @@ builder.getI64IntegerAttr(static_cast(predicate))); } -mlir::CmpFPredicate fir::CmpcOp::getPredicateByName(llvm::StringRef name) { - auto pred = mlir::symbolizeCmpFPredicate(name); +mlir::arith::CmpFPredicate +fir::CmpcOp::getPredicateByName(llvm::StringRef name) { + auto pred = mlir::arith::symbolizeCmpFPredicate(name); assert(pred.hasValue() && "invalid predicate name"); return pred.getValue(); } @@ -1276,9 +1278,9 @@ static void appendAsAttribute(llvm::SmallVectorImpl &attrs, mlir::Value val) { if (auto *op = val.getDefiningOp()) { - if (auto cop = mlir::dyn_cast(op)) { + if (auto cop = mlir::dyn_cast(op)) { // append the integer constant value - if (auto iattr = cop.getValue().dyn_cast()) { + if (auto iattr = cop.value().dyn_cast()) { attrs.push_back(iattr); return; } @@ -1505,8 +1507,8 @@ void fir::InsertValueOp::getCanonicalizationPatterns( mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.insert, - UndoComplexPattern>(context); + results.insert, + UndoComplexPattern>(context); } //===----------------------------------------------------------------------===// @@ -3239,7 +3241,7 @@ if (auto *op = (*i++).getDefiningOp()) { if (auto off = mlir::dyn_cast(op)) return ty.getType(off.getFieldName()); - if (auto off = mlir::dyn_cast(op)) + if (auto off = mlir::dyn_cast(op)) return ty.getType(fir::toInt(off)); } return mlir::Type{}; @@ -3254,7 +3256,7 @@ }) .Case([&](mlir::TupleType ty) { if (auto *op = (*i++).getDefiningOp()) - if (auto off = mlir::dyn_cast(op)) + if (auto off = mlir::dyn_cast(op)) return ty.getType(fir::toInt(off)); return mlir::Type{}; }) diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -248,7 +248,8 @@ return; // Convert the calls and, if needed, the ReturnOp in the function body. 
- target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalOp(); target.addDynamicallyLegalOp([](fir::CallOp call) { return !mustConvertCallOrFunc(call.getFunctionType()); diff --git a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp --- a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp @@ -144,6 +144,7 @@ return true; }); target.addLegalDialect(); if (mlir::failed(mlir::applyPartialConversion(function, target, diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -157,7 +157,7 @@ using MaybeAffineExpr = llvm::Optional; explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) { - if (auto condDef = firCondition.getDefiningOp()) + if (auto condDef = firCondition.getDefiningOp()) fromCmpIOp(condDef); } @@ -193,19 +193,19 @@ /// in an affine expression, this includes -, +, *, rem, constant. /// block arguments of a loopOp or forOp are used as dimensions MaybeAffineExpr toAffineExpr(mlir::Value value) { - if (auto op = value.getDefiningOp()) + if (auto op = value.getDefiningOp()) return affineBinaryOp(mlir::AffineExprKind::Add, toAffineExpr(op.lhs()), affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.rhs()), toAffineExpr(-1))); - if (auto op = value.getDefiningOp()) + if (auto op = value.getDefiningOp()) return affineBinaryOp(mlir::AffineExprKind::Add, op.lhs(), op.rhs()); - if (auto op = value.getDefiningOp()) + if (auto op = value.getDefiningOp()) return affineBinaryOp(mlir::AffineExprKind::Mul, op.lhs(), op.rhs()); - if (auto op = value.getDefiningOp()) + if (auto op = value.getDefiningOp()) return affineBinaryOp(mlir::AffineExprKind::Mod, op.lhs(), op.rhs()); - if (auto op = value.getDefiningOp()) - if (auto intConstant = op.getValue().dyn_cast()) + if (auto op = value.getDefiningOp()) + if (auto intConstant = op.value().dyn_cast()) return toAffineExpr(intConstant.getInt()); if (auto blockArg = value.dyn_cast()) { affineArgs.push_back(value); @@ -217,7 +217,7 @@ return {}; } - void fromCmpIOp(mlir::CmpIOp cmpOp) { + void fromCmpIOp(mlir::arith::CmpIOp cmpOp) { auto lhsAffine = toAffineExpr(cmpOp.lhs()); auto rhsAffine = toAffineExpr(cmpOp.rhs()); if (!lhsAffine.hasValue() || !rhsAffine.hasValue()) @@ -233,17 +233,17 @@ } llvm::Optional> - constraint(mlir::CmpIPredicate predicate, mlir::AffineExpr basic) { + constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) { switch (predicate) { - case mlir::CmpIPredicate::slt: + case mlir::arith::CmpIPredicate::slt: return {std::make_pair(basic - 1, false)}; - case mlir::CmpIPredicate::sle: + case mlir::arith::CmpIPredicate::sle: return {std::make_pair(basic, false)}; - case mlir::CmpIPredicate::sgt: + case mlir::arith::CmpIPredicate::sgt: return {std::make_pair(1 - basic, false)}; - case mlir::CmpIPredicate::sge: + case mlir::arith::CmpIPredicate::sge: return {std::make_pair(0 - basic, false)}; - case mlir::CmpIPredicate::eq: + case mlir::arith::CmpIPredicate::eq: return {std::make_pair(basic, true)}; default: return {}; @@ -315,8 +315,8 @@ } static Optional constantIntegerLike(const mlir::Value value) { - if (auto definition = value.getDefiningOp()) - if (auto stepAttr = definition.getValue().dyn_cast()) + if (auto definition = value.getDefiningOp()) + if (auto stepAttr = definition.value().dyn_cast()) return 
stepAttr.getInt(); return {}; } @@ -335,7 +335,7 @@ static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape, SmallVectorImpl &indexArgs, mlir::PatternRewriter &rewriter) { - auto one = rewriter.create( + auto one = rewriter.create( acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); auto extents = shape.extents(); for (auto i = extents.begin(); i < extents.end(); i++) { @@ -348,7 +348,7 @@ static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape, SmallVectorImpl &indexArgs, mlir::PatternRewriter &rewriter) { - auto one = rewriter.create( + auto one = rewriter.create( acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); auto extents = shape.pairs(); for (auto i = extents.begin(); i < extents.end();) { @@ -579,8 +579,9 @@ patterns.insert(context, functionAnalysis); patterns.insert(context, functionAnalysis); mlir::ConversionTarget target = *context; - target.addLegalDialect(); + target.addLegalDialect< + mlir::AffineDialect, FIROpsDialect, mlir::scf::SCFDialect, + mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect>(); target.addDynamicallyLegalOp([&functionAnalysis](fir::IfOp op) { return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); }); diff --git a/flang/lib/Optimizer/Transforms/CharacterConversion.cpp b/flang/lib/Optimizer/Transforms/CharacterConversion.cpp --- a/flang/lib/Optimizer/Transforms/CharacterConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CharacterConversion.cpp @@ -43,11 +43,11 @@ << "running character conversion on " << conv << '\n'); // Establish a loop that executes count iterations. - auto zero = rewriter.create(loc, 0); - auto one = rewriter.create(loc, 1); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); auto idxTy = rewriter.getIndexType(); auto castCnt = rewriter.create(loc, idxTy, conv.count()); - auto countm1 = rewriter.create(loc, castCnt, one); + auto countm1 = rewriter.create(loc, castCnt, one); auto loop = rewriter.create(loc, zero, countm1, one); auto insPt = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(loop.getBody()); @@ -83,7 +83,8 @@ mlir::Value icast = (fromBits >= toBits) ? rewriter.create(loc, toTy, load).getResult() - : rewriter.create(loc, toTy, load).getResult(); + : rewriter.create(loc, toTy, load) + .getResult(); rewriter.replaceOpWithNewOp(conv, icast, toi); rewriter.restoreInsertionPoint(insPt); return mlir::success(); @@ -104,6 +105,7 @@ patterns.insert(context); mlir::ConversionTarget target(*context); target.addLegalDialect(); // apply the patterns diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir --- a/flang/test/Fir/abstract-results.fir +++ b/flang/test/Fir/abstract-results.fir @@ -28,9 +28,9 @@ func private @arrayfunc_callee(%n : index) -> !fir.array { %buffer = fir.alloca !fir.array, %n // Do something with result (res(4) = 42.) 
- %c4 = constant 4 : i64 + %c4 = arith.constant 4 : i64 %coor = fir.coordinate_of %buffer, %c4 : (!fir.ref>, i64) -> !fir.ref - %cst = constant 4.200000e+01 : f32 + %cst = arith.constant 4.200000e+01 : f32 fir.store %cst to %coor : !fir.ref %res = fir.load %buffer : !fir.ref> return %res : !fir.array @@ -90,19 +90,19 @@ // CHECK-LABEL: func @call_arrayfunc() { // CHECK-BOX-LABEL: func @call_arrayfunc() { func @call_arrayfunc() { - %c100 = constant 100 : index + %c100 = arith.constant 100 : index %buffer = fir.alloca !fir.array, %c100 %shape = fir.shape %c100 : (index) -> !fir.shape<1> %res = fir.call @arrayfunc_callee(%c100) : (index) -> !fir.array fir.save_result %res to %buffer(%shape) : !fir.array, !fir.ref>, !fir.shape<1> return - // CHECK: %[[c100:.*]] = constant 100 : index + // CHECK: %[[c100:.*]] = arith.constant 100 : index // CHECK: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] // CHECK: fir.call @arrayfunc_callee(%[[buffer]], %[[c100]]) : (!fir.ref>, index) -> () // CHECK-NOT: fir.save_result - // CHECK-BOX: %[[c100:.*]] = constant 100 : index + // CHECK-BOX: %[[c100:.*]] = arith.constant 100 : index // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> @@ -114,17 +114,17 @@ // CHECK-BOX-LABEL: func @call_derivedfunc() { func @call_derivedfunc() { %buffer = fir.alloca !fir.type - %cst = constant 4.200000e+01 : f32 + %cst = arith.constant 4.200000e+01 : f32 %res = fir.call @derivedfunc_callee(%cst) : (f32) -> !fir.type fir.save_result %res to %buffer : !fir.type, !fir.ref> return // CHECK: %[[buffer:.*]] = fir.alloca !fir.type - // CHECK: %[[cst:.*]] = constant {{.*}} : f32 + // CHECK: %[[cst:.*]] = arith.constant {{.*}} : f32 // CHECK: fir.call @derivedfunc_callee(%[[buffer]], %[[cst]]) : (!fir.ref>, f32) -> () // CHECK-NOT: fir.save_result // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.type - // CHECK-BOX: %[[cst:.*]] = constant {{.*}} : f32 + // CHECK-BOX: %[[cst:.*]] = arith.constant {{.*}} : f32 // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref>) -> !fir.box> // CHECK-BOX: fir.call @derivedfunc_callee(%[[box]], %[[cst]]) : (!fir.box>, f32) -> () // CHECK-BOX-NOT: fir.save_result @@ -137,19 +137,19 @@ // CHECK-BOX-LABEL: func @call_derived_lparams_func( // CHECK-BOX-SAME: %[[buffer:.*]]: !fir.ref> func @call_derived_lparams_func(%buffer: !fir.ref>) { - %l1 = constant 3 : i32 - %l2 = constant 5 : i32 + %l1 = arith.constant 3 : i32 + %l2 = arith.constant 5 : i32 %res = fir.call @derived_lparams_func() : () -> !fir.type fir.save_result %res to %buffer typeparams %l1, %l2 : !fir.type, !fir.ref>, i32, i32 return - // CHECK: %[[l1:.*]] = constant 3 : i32 - // CHECK: %[[l2:.*]] = constant 5 : i32 + // CHECK: %[[l1:.*]] = arith.constant 3 : i32 + // CHECK: %[[l2:.*]] = arith.constant 5 : i32 // CHECK: fir.call @derived_lparams_func(%[[buffer]]) : (!fir.ref>) -> () // CHECK-NOT: fir.save_result - // CHECK-BOX: %[[l1:.*]] = constant 3 : i32 - // CHECK-BOX: %[[l2:.*]] = constant 5 : i32 + // CHECK-BOX: %[[l1:.*]] = arith.constant 3 : i32 + // CHECK-BOX: %[[l2:.*]] = arith.constant 5 : i32 // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]] typeparams %[[l1]], %[[l2]] : (!fir.ref>, i32, i32) -> !fir.box> // CHECK-BOX: fir.call @derived_lparams_func(%[[box]]) : (!fir.box>) -> () // CHECK-BOX-NOT: fir.save_result @@ -177,22 +177,22 @@ // CHECK-LABEL: func @call_chararrayfunc() { // 
CHECK-BOX-LABEL: func @call_chararrayfunc() { func @call_chararrayfunc() { - %c100 = constant 100 : index - %c50 = constant 50 : index + %c100 = arith.constant 100 : index + %c50 = arith.constant 50 : index %buffer = fir.alloca !fir.array>(%c100 : index), %c50 %shape = fir.shape %c100 : (index) -> !fir.shape<1> %res = fir.call @chararrayfunc(%c100, %c50) : (index, index) -> !fir.array> fir.save_result %res to %buffer(%shape) typeparams %c50 : !fir.array>, !fir.ref>>, !fir.shape<1>, index return - // CHECK: %[[c100:.*]] = constant 100 : index - // CHECK: %[[c50:.*]] = constant 50 : index + // CHECK: %[[c100:.*]] = arith.constant 100 : index + // CHECK: %[[c50:.*]] = arith.constant 50 : index // CHECK: %[[buffer:.*]] = fir.alloca !fir.array>(%[[c100]] : index), %[[c50]] // CHECK: fir.call @chararrayfunc(%[[buffer]], %[[c100]], %[[c50]]) : (!fir.ref>>, index, index) -> () // CHECK-NOT: fir.save_result - // CHECK-BOX: %[[c100:.*]] = constant 100 : index - // CHECK-BOX: %[[c50:.*]] = constant 50 : index + // CHECK-BOX: %[[c100:.*]] = arith.constant 100 : index + // CHECK-BOX: %[[c50:.*]] = arith.constant 50 : index // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array>(%[[c100]] : index), %[[c50]] // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) typeparams %[[c50]] : (!fir.ref>>, !fir.shape<1>, index) -> !fir.box>> @@ -228,7 +228,7 @@ // CHECK-BOX-LABEL: func @test_indirect_calls( // CHECK-BOX-SAME: %[[arg0:.*]]: () -> ()) { func @test_indirect_calls(%arg0: () -> ()) { - %c100 = constant 100 : index + %c100 = arith.constant 100 : index %buffer = fir.alloca !fir.array, %c100 %shape = fir.shape %c100 : (index) -> !fir.shape<1> %0 = fir.convert %arg0 : (() -> ()) -> ((index) -> !fir.array) @@ -236,7 +236,7 @@ fir.save_result %res to %buffer(%shape) : !fir.array, !fir.ref>, !fir.shape<1> return - // CHECK: %[[c100:.*]] = constant 100 : index + // CHECK: %[[c100:.*]] = arith.constant 100 : index // CHECK: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] // CHECK: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> // CHECK: %[[original_conv:.*]] = fir.convert %[[arg0]] : (() -> ()) -> ((index) -> !fir.array) @@ -244,7 +244,7 @@ // CHECK: fir.call %[[conv]](%[[buffer]], %c100) : (!fir.ref>, index) -> () // CHECK-NOT: fir.save_result - // CHECK-BOX: %[[c100:.*]] = constant 100 : index + // CHECK-BOX: %[[c100:.*]] = arith.constant 100 : index // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> // CHECK-BOX: %[[original_conv:.*]] = fir.convert %[[arg0]] : (() -> ()) -> ((index) -> !fir.array) diff --git a/flang/test/Fir/affine-demotion.fir b/flang/test/Fir/affine-demotion.fir --- a/flang/test/Fir/affine-demotion.fir +++ b/flang/test/Fir/affine-demotion.fir @@ -7,8 +7,8 @@ #map2 = affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)> module { func @calc(%arg0: !fir.ref>, %arg1: !fir.ref>, %arg2: !fir.ref>) { - %c1 = constant 1 : index - %c100 = constant 100 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index %0 = fir.shape %c100 : (index) -> !fir.shape<1> %1 = affine.apply #map0()[%c1, %c100] %2 = fir.alloca !fir.array, %1 @@ -19,7 +19,7 @@ %7 = affine.apply #map2(%arg3)[%c1, %c100, %c1] %8 = affine.load %3[%7] : memref %9 = affine.load %4[%7] : memref - %10 = addf %8, %9 : f32 + %10 = arith.addf %8, %9 : f32 affine.store %10, %5[%7] : memref } %6 = fir.convert %arg2 : (!fir.ref>) -> memref @@ -27,7 
+27,7 @@ %7 = affine.apply #map2(%arg3)[%c1, %c100, %c1] %8 = affine.load %5[%7] : memref %9 = affine.load %4[%7] : memref - %10 = mulf %8, %9 : f32 + %10 = arith.mulf %8, %9 : f32 affine.store %10, %6[%7] : memref } return @@ -35,10 +35,10 @@ } // CHECK: func @calc(%[[VAL_0:.*]]: !fir.ref>, %[[VAL_1:.*]]: !fir.ref>, %[[VAL_2:.*]]: !fir.ref>) { -// CHECK: %[[VAL_3:.*]] = constant 1 : index -// CHECK: %[[VAL_4:.*]] = constant 100 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 100 : index // CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1> -// CHECK: %[[VAL_6:.*]] = constant 100 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 100 : index // CHECK: %[[VAL_7:.*]] = fir.alloca !fir.array, %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (!fir.ref>) -> !fir.ref> @@ -49,7 +49,7 @@ // CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref // CHECK: %[[VAL_15:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_12]] : (!fir.ref>, index) -> !fir.ref // CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref -// CHECK: %[[VAL_17:.*]] = addf %[[VAL_14]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_17:.*]] = arith.addf %[[VAL_14]], %[[VAL_16]] : f32 // CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_10]], %[[VAL_12]] : (!fir.ref>, index) -> !fir.ref // CHECK: fir.store %[[VAL_17]] to %[[VAL_18]] : !fir.ref // CHECK: } @@ -60,7 +60,7 @@ // CHECK: %[[VAL_23:.*]] = fir.load %[[VAL_22]] : !fir.ref // CHECK: %[[VAL_24:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_21]] : (!fir.ref>, index) -> !fir.ref // CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_24]] : !fir.ref -// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_23]], %[[VAL_25]] : f32 +// CHECK: %[[VAL_26:.*]] = arith.mulf %[[VAL_23]], %[[VAL_25]] : f32 // CHECK: %[[VAL_27:.*]] = fir.coordinate_of %[[VAL_19]], %[[VAL_21]] : (!fir.ref>, index) -> !fir.ref // CHECK: fir.store %[[VAL_26]] to %[[VAL_27]] : !fir.ref // CHECK: } diff --git a/flang/test/Fir/affine-promotion.fir b/flang/test/Fir/affine-promotion.fir --- a/flang/test/Fir/affine-promotion.fir +++ b/flang/test/Fir/affine-promotion.fir @@ -6,9 +6,9 @@ #arr_len = affine_map<()[j1,k1] -> (k1 - j1 + 1)> func @loop_with_load_and_store(%a1: !arr_d1, %a2: !arr_d1, %a3: !arr_d1) { - %c1 = constant 1 : index - %c0 = constant 0 : index - %len = constant 100 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %len = arith.constant 100 : index %dims = fir.shape %len : (index) -> !fir.shape<1> %siz = affine.apply #arr_len()[%c1,%len] %t1 = fir.alloca !fir.array, %siz @@ -22,7 +22,7 @@ : (!arr_d1, !fir.shape<1>, index) -> !fir.ref %a2_v = fir.load %a2_idx : !fir.ref - %v = addf %a1_v, %a2_v : f32 + %v = arith.addf %a1_v, %a2_v : f32 %t1_idx = fir.array_coor %t1(%dims) %i : (!arr_d1, !fir.shape<1>, index) -> !fir.ref @@ -37,7 +37,7 @@ : (!arr_d1, !fir.shape<1>, index) -> !fir.ref %a2_v = fir.load %a2_idx : !fir.ref - %v = mulf %t1_v, %a2_v : f32 + %v = arith.mulf %t1_v, %a2_v : f32 %a3_idx = fir.array_coor %a3(%dims) %i : (!arr_d1, !fir.shape<1>, index) -> !fir.ref @@ -47,8 +47,8 @@ } // CHECK: func @loop_with_load_and_store(%[[VAL_0:.*]]: !fir.ref>, %[[VAL_1:.*]]: !fir.ref>, %[[VAL_2:.*]]: !fir.ref>) { -// CHECK: %[[VAL_3:.*]] = constant 1 : index -// CHECK: %[[VAL_4:.*]] = constant 100 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 100 : index // CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> 
!fir.shape<1> // CHECK: %[[VAL_6:.*]] = affine.apply #map0(){{\[}}%[[VAL_3]], %[[VAL_4]]] // CHECK: %[[VAL_7:.*]] = fir.alloca !fir.array, %[[VAL_6]] @@ -59,7 +59,7 @@ // CHECK: %[[VAL_12:.*]] = affine.apply #map2(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]] // CHECK: %[[VAL_13:.*]] = affine.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref // CHECK: %[[VAL_14:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref -// CHECK: %[[VAL_15:.*]] = addf %[[VAL_13]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 // CHECK: affine.store %[[VAL_15]], %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref // CHECK: } // CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_2]] : (!fir.ref>) -> memref @@ -67,7 +67,7 @@ // CHECK: %[[VAL_18:.*]] = affine.apply #map2(%[[VAL_17]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]] // CHECK: %[[VAL_19:.*]] = affine.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref // CHECK: %[[VAL_20:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref -// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_19]], %[[VAL_20]] : f32 +// CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32 // CHECK: affine.store %[[VAL_21]], %[[VAL_16]]{{\[}}%[[VAL_18]]] : memref // CHECK: } // CHECK: return @@ -79,17 +79,17 @@ #arr_len = affine_map<()[j1,k1] -> (k1 - j1 + 1)> func @loop_with_if(%a: !arr_d1, %v: f32) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c2 = constant 2 : index - %len = constant 100 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %len = arith.constant 100 : index %dims = fir.shape %len : (index) -> !fir.shape<1> fir.do_loop %i = %c1 to %len step %c1 { fir.do_loop %j = %c1 to %len step %c1 { fir.do_loop %k = %c1 to %len step %c1 { - %im2 = subi %i, %c2 : index - %cond = cmpi "sgt", %im2, %c0 : index + %im2 = arith.subi %i, %c2 : index + %cond = arith.cmpi "sgt", %im2, %c0 : index fir.if %cond { %a_idx = fir.array_coor %a(%dims) %i : (!arr_d1, !fir.shape<1>, index) -> !fir.ref @@ -108,10 +108,10 @@ } // CHECK: func @loop_with_if(%[[VAL_0:.*]]: !fir.ref>, %[[VAL_1:.*]]: f32) { -// CHECK: %[[VAL_2:.*]] = constant 0 : index -// CHECK: %[[VAL_3:.*]] = constant 1 : index -// CHECK: %[[VAL_4:.*]] = constant 2 : index -// CHECK: %[[VAL_5:.*]] = constant 100 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 100 : index // CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1> // CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> memref // CHECK: affine.for %[[VAL_8:.*]] = %[[VAL_3]] to #map0(){{\[}}%[[VAL_5]]] { @@ -123,7 +123,7 @@ // CHECK: affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref // CHECK: } // CHECK: affine.for %[[VAL_12:.*]] = %[[VAL_3]] to #map0(){{\[}}%[[VAL_5]]] { -// CHECK: %[[VAL_13:.*]] = subi %[[VAL_12]], %[[VAL_4]] : index +// CHECK: %[[VAL_13:.*]] = arith.subi %[[VAL_12]], %[[VAL_4]] : index // CHECK: affine.if #set(%[[VAL_12]]) { // CHECK: %[[VAL_14:.*]] = affine.apply #map1(%[[VAL_12]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]] // CHECK: affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref diff --git a/flang/test/Fir/cg-ops.fir b/flang/test/Fir/cg-ops.fir --- a/flang/test/Fir/cg-ops.fir +++ b/flang/test/Fir/cg-ops.fir @@ -3,8 +3,8 @@ // CHECK-LABEL: func @codegen( // CHECK-SAME: %[[arg:.*]]: !fir func @codegen(%addr : !fir.ref>) { - // CHECK: %[[zero:.*]] = constant 0 : index 
- %0 = constant 0 : index + // CHECK: %[[zero:.*]] = arith.constant 0 : index + %0 = arith.constant 0 : index %1 = fir.shape_shift %0, %0 : (index, index) -> !fir.shapeshift<1> %2 = fir.slice %0, %0, %0 : (index, index, index) -> !fir.slice<1> // CHECK: %[[box:.*]] = fircg.ext_embox %[[arg]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]] : (!fir.ref>, index, index, index, index, index) -> !fir.box> @@ -20,8 +20,8 @@ fir.global @box_global : !fir.box> { // CHECK: %[[arr:.*]] = fir.zero_bits !fir.ref %arr = fir.zero_bits !fir.ref> - // CHECK: %[[zero:.*]] = constant 0 : index - %0 = constant 0 : index + // CHECK: %[[zero:.*]] = arith.constant 0 : index + %0 = arith.constant 0 : index %1 = fir.shape_shift %0, %0 : (index, index) -> !fir.shapeshift<1> %2 = fir.slice %0, %0, %0 : (index, index, index) -> !fir.slice<1> // CHECK: fircg.ext_embox %[[arr]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]] : (!fir.ref>, index, index, index, index, index) -> !fir.box> diff --git a/flang/test/Fir/char-conversion.fir b/flang/test/Fir/char-conversion.fir --- a/flang/test/Fir/char-conversion.fir +++ b/flang/test/Fir/char-conversion.fir @@ -12,17 +12,17 @@ // CHECK: %[[VAL_0:.*]] = fir.undefined i32 // CHECK: %[[VAL_1:.*]] = fir.undefined !fir.ref> // CHECK: %[[VAL_2:.*]] = fir.undefined !fir.ref>> -// CHECK: %[[VAL_3:.*]] = constant 0 : index -// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_0]] : (i32) -> index -// CHECK: %[[VAL_6:.*]] = subi %[[VAL_5]], %[[VAL_4]] : index +// CHECK: %[[VAL_6:.*]] = arith.subi %[[VAL_5]], %[[VAL_4]] : index // CHECK: fir.do_loop %[[VAL_7:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] { // CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_1]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_2]] : (!fir.ref>>) -> !fir.ref> // CHECK: %[[VAL_10:.*]] = fir.coordinate_of %[[VAL_8]], %[[VAL_7]] : (!fir.ref>, index) -> !fir.ref // CHECK: %[[VAL_11:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_7]] : (!fir.ref>, index) -> !fir.ref // CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_10]] : !fir.ref -// CHECK: %[[VAL_13:.*]] = zexti %[[VAL_12]] : i8 to i16 +// CHECK: %[[VAL_13:.*]] = arith.extui %[[VAL_12]] : i8 to i16 // CHECK: fir.store %[[VAL_13]] to %[[VAL_11]] : !fir.ref // CHECK: } // CHECK: return diff --git a/flang/test/Fir/convert-fold.fir b/flang/test/Fir/convert-fold.fir --- a/flang/test/Fir/convert-fold.fir +++ b/flang/test/Fir/convert-fold.fir @@ -29,9 +29,9 @@ // CHECK-LABEL: @ctest func @ctest() -> index { - %1 = constant 10 : i32 + %1 = arith.constant 10 : i32 %2 = fir.convert %1 : (i32) -> index - // CHECK-NEXT: %{{.*}} = constant 10 : index + // CHECK-NEXT: %{{.*}} = arith.constant 10 : index // CHECK-NEXT: return %{{.*}} : index return %2 : index } diff --git a/flang/test/Fir/external-mangling.fir b/flang/test/Fir/external-mangling.fir --- a/flang/test/Fir/external-mangling.fir +++ b/flang/test/Fir/external-mangling.fir @@ -1,7 +1,7 @@ // RUN: fir-opt --external-name-interop %s | FileCheck %s func @_QPfoo() { - %c0 = constant 0 : index + %c0 = arith.constant 0 : index %0 = fir.address_of(@_QBa) : !fir.ref> %1 = fir.convert %0 : (!fir.ref>) -> !fir.ref> %2 = fir.coordinate_of %1, %c0 : (!fir.ref>, index) -> !fir.ref diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir --- a/flang/test/Fir/fir-ops.fir +++ b/flang/test/Fir/fir-ops.fir @@ -37,11 +37,11 @@ // CHECK: [[VAL_0:%.*]] 
= fir.alloca !fir.array<10xi32> // CHECK: [[VAL_1:%.*]] = fir.load [[VAL_0]] : !fir.ref> // CHECK: [[VAL_2:%.*]] = fir.alloca i32 -// CHECK: [[VAL_3:%.*]] = constant 22 : i32 +// CHECK: [[VAL_3:%.*]] = arith.constant 22 : i32 %0 = fir.alloca !fir.array<10xi32> %1 = fir.load %0 : !fir.ref> %2 = fir.alloca i32 - %3 = constant 22 : i32 + %3 = arith.constant 22 : i32 // CHECK: fir.store [[VAL_3]] to [[VAL_2]] : !fir.ref // CHECK: [[VAL_4:%.*]] = fir.undefined i32 @@ -53,12 +53,12 @@ %6 = fir.embox %5 : (!fir.heap>) -> !fir.box> // CHECK: [[VAL_7:%.*]] = fir.box_addr [[VAL_6]] : (!fir.box>) -> !fir.ref> -// CHECK: [[VAL_8:%.*]] = constant 0 : index +// CHECK: [[VAL_8:%.*]] = arith.constant 0 : index // CHECK: [[VAL_9:%.*]]:3 = fir.box_dims [[VAL_6]], [[VAL_8]] : (!fir.box>, index) -> (index, index, index) // CHECK: fir.call @print_index3([[VAL_9]]#0, [[VAL_9]]#1, [[VAL_9]]#2) : (index, index, index) -> () // CHECK: [[VAL_10:%.*]] = fir.call @it1() : () -> !fir.int<4> %7 = fir.box_addr %6 : (!fir.box>) -> !fir.ref> - %c0 = constant 0 : index + %c0 = arith.constant 0 : index %d1:3 = fir.box_dims %6, %c0 : (!fir.box>, index) -> (index, index, index) fir.call @print_index3(%d1#0, %d1#1, %d1#2) : (index, index, index) -> () %8 = fir.call @it1() : () -> !fir.int<4> @@ -85,25 +85,25 @@ %17 = fir.call @box2() : () -> !fir.boxproc<(i32, i32) -> i64> %18 = fir.boxproc_host %17 : (!fir.boxproc<(i32, i32) -> i64>) -> !fir.ref -// CHECK: [[VAL_21:%.*]] = constant 10 : i32 +// CHECK: [[VAL_21:%.*]] = arith.constant 10 : i32 // CHECK: [[VAL_22:%.*]] = fir.coordinate_of [[VAL_5]], [[VAL_21]] : (!fir.heap>, i32) -> !fir.ref // CHECK: [[VAL_23:%.*]] = fir.field_index f, !fir.type // CHECK: [[VAL_24:%.*]] = fir.undefined !fir.type // CHECK: [[VAL_25:%.*]] = fir.extract_value [[VAL_24]], ["f", !fir.type] : (!fir.type) -> f32 - %19 = constant 10 : i32 + %19 = arith.constant 10 : i32 %20 = fir.coordinate_of %5, %19 : (!fir.heap>, i32) -> !fir.ref %21 = fir.field_index f, !fir.type %22 = fir.undefined !fir.type %23 = fir.extract_value %22, ["f", !fir.type] : (!fir.type) -> f32 -// CHECK: [[VAL_26:%.*]] = constant 1 : i32 +// CHECK: [[VAL_26:%.*]] = arith.constant 1 : i32 // CHECK: [[VAL_27:%.*]] = fir.shape [[VAL_21]] : (i32) -> !fir.shape<1> -// CHECK: [[VAL_28:%.*]] = constant 1.0 +// CHECK: [[VAL_28:%.*]] = arith.constant 1.0 // CHECK: [[VAL_29:%.*]] = fir.insert_value [[VAL_24]], [[VAL_28]], ["f", !fir.type] : (!fir.type, f32) -> !fir.type // CHECK: [[VAL_30:%.*]] = fir.len_param_index f, !fir.type - %c1 = constant 1 : i32 + %c1 = arith.constant 1 : i32 %24 = fir.shape %19 : (i32) -> !fir.shape<1> - %cf1 = constant 1.0 : f32 + %cf1 = arith.constant 1.0 : f32 %25 = fir.insert_value %22, %cf1, ["f", !fir.type] : (!fir.type, f32) -> !fir.type %26 = fir.len_param_index f, !fir.type @@ -143,16 +143,16 @@ // CHECK: [[VAL_41:%.*]] = fir.alloca tuple // CHECK: [[VAL_42:%.*]] = fir.embox [[VAL_38]] : (!fir.ref) -> !fir.box // CHECK: [[VAL_43:%.*]]:6 = fir.unbox [[VAL_42]] : (!fir.box) -> (!fir.ref, i32, i32, !fir.tdesc, i32, !fir.array<3x?xindex>) -// CHECK: [[VAL_44:%.*]] = constant 8 : i32 +// CHECK: [[VAL_44:%.*]] = arith.constant 8 : i32 // CHECK: [[VAL_45:%.*]] = fir.undefined !fir.char<1> // CHECK: [[VAL_46:%.*]] = fir.emboxchar [[VAL_40]], [[VAL_44]] : (!fir.ref>, i32) -> !fir.boxchar<1> // CHECK: [[VAL_47:%.*]]:2 = fir.unboxchar [[VAL_46]] : (!fir.boxchar<1>) -> (!fir.ref>, i32) // CHECK: [[VAL_48:%.*]] = fir.undefined !fir.type -// CHECK: [[VAL_49:%.*]] = constant 0 : i32 -// CHECK: [[VAL_50:%.*]] = constant 12 : 
i32 +// CHECK: [[VAL_49:%.*]] = arith.constant 0 : i32 +// CHECK: [[VAL_50:%.*]] = arith.constant 12 : i32 // CHECK: [[VAL_51:%.*]] = fir.insert_value [[VAL_48]], [[VAL_50]], [0 : i32] : (!fir.type, i32) -> !fir.type -// CHECK: [[VAL_52:%.*]] = constant 1 : i32 -// CHECK: [[VAL_53:%.*]] = constant 4.213000e+01 : f64 +// CHECK: [[VAL_52:%.*]] = arith.constant 1 : i32 +// CHECK: [[VAL_53:%.*]] = arith.constant 4.213000e+01 : f64 // CHECK: [[VAL_54:%.*]] = fir.insert_value [[VAL_48]], [[VAL_53]], [1 : i32] : (!fir.type, f64) -> !fir.type // CHECK: fir.store [[VAL_54]] to [[VAL_39]] : !fir.ref> // CHECK: [[VAL_55:%.*]] = fir.emboxproc @method_impl, [[VAL_41]] : ((!fir.box>) -> (), !fir.ref>) -> !fir.boxproc<(!fir.box>) -> ()> @@ -169,16 +169,16 @@ %e6 = fir.alloca tuple %1 = fir.embox %0 : (!fir.ref) -> !fir.box %2:6 = fir.unbox %1 : (!fir.box) -> (!fir.ref,i32,i32,!fir.tdesc,i32,!fir.array<3x?xindex>) - %c8 = constant 8 : i32 + %c8 = arith.constant 8 : i32 %3 = fir.undefined !fir.char<1> %4 = fir.emboxchar %d3, %c8 : (!fir.ref>, i32) -> !fir.boxchar<1> %5:2 = fir.unboxchar %4 : (!fir.boxchar<1>) -> (!fir.ref>, i32) %6 = fir.undefined !fir.type - %z = constant 0 : i32 - %c12 = constant 12 : i32 + %z = arith.constant 0 : i32 + %c12 = arith.constant 12 : i32 %a2 = fir.insert_value %6, %c12, [0 : i32] : (!fir.type, i32) -> !fir.type - %z1 = constant 1 : i32 - %c42 = constant 42.13 : f64 + %z1 = arith.constant 1 : i32 + %c42 = arith.constant 42.13 : f64 %a3 = fir.insert_value %6, %c42, [1 : i32] : (!fir.type, f64) -> !fir.type fir.store %a3 to %d6 : !fir.ref> %7 = fir.emboxproc @method_impl, %e6 : ((!fir.box>) -> (), !fir.ref>) -> !fir.boxproc<(!fir.box>) -> ()> @@ -192,12 +192,12 @@ // CHECK-LABEL: func @loop() { func @loop() { -// CHECK: [[VAL_62:%.*]] = constant 1 : index -// CHECK: [[VAL_63:%.*]] = constant 10 : index -// CHECK: [[VAL_64:%.*]] = constant true - %c1 = constant 1 : index - %c10 = constant 10 : index - %ct = constant true +// CHECK: [[VAL_62:%.*]] = arith.constant 1 : index +// CHECK: [[VAL_63:%.*]] = arith.constant 10 : index +// CHECK: [[VAL_64:%.*]] = arith.constant true + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %ct = arith.constant true // CHECK: fir.do_loop [[VAL_65:%.*]] = [[VAL_62]] to [[VAL_63]] step [[VAL_62]] { // CHECK: fir.if [[VAL_64]] { @@ -220,92 +220,92 @@ // CHECK: func @bar_select([[VAL_66:%.*]]: i32, [[VAL_67:%.*]]: i32) -> i32 { func @bar_select(%arg : i32, %arg2 : i32) -> i32 { -// CHECK: [[VAL_68:%.*]] = constant 1 : i32 -// CHECK: [[VAL_69:%.*]] = constant 2 : i32 -// CHECK: [[VAL_70:%.*]] = constant 3 : i32 -// CHECK: [[VAL_71:%.*]] = constant 4 : i32 - %0 = constant 1 : i32 - %1 = constant 2 : i32 - %2 = constant 3 : i32 - %3 = constant 4 : i32 +// CHECK: [[VAL_68:%.*]] = arith.constant 1 : i32 +// CHECK: [[VAL_69:%.*]] = arith.constant 2 : i32 +// CHECK: [[VAL_70:%.*]] = arith.constant 3 : i32 +// CHECK: [[VAL_71:%.*]] = arith.constant 4 : i32 + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + %3 = arith.constant 4 : i32 // CHECK: fir.select [[VAL_66]] : i32 [1, ^bb1([[VAL_68]] : i32), 2, ^bb2([[VAL_70]], [[VAL_66]], [[VAL_67]] : i32, i32, i32), -3, ^bb3([[VAL_67]], [[VAL_70]] : i32, i32), 4, ^bb4([[VAL_69]] : i32), unit, ^bb5] // CHECK: ^bb1([[VAL_72:%.*]]: i32): // CHECK: return [[VAL_72]] : i32 // CHECK: ^bb2([[VAL_73:%.*]]: i32, [[VAL_74:%.*]]: i32, [[VAL_75:%.*]]: i32): -// CHECK: [[VAL_76:%.*]] = addi [[VAL_73]], [[VAL_74]] : i32 -// CHECK: [[VAL_77:%.*]] = addi [[VAL_76]], 
[[VAL_75]] : i32 +// CHECK: [[VAL_76:%.*]] = arith.addi [[VAL_73]], [[VAL_74]] : i32 +// CHECK: [[VAL_77:%.*]] = arith.addi [[VAL_76]], [[VAL_75]] : i32 // CHECK: return [[VAL_77]] : i32 // CHECK: ^bb3([[VAL_78:%.*]]: i32, [[VAL_79:%.*]]: i32): -// CHECK: [[VAL_80:%.*]] = addi [[VAL_78]], [[VAL_79]] : i32 +// CHECK: [[VAL_80:%.*]] = arith.addi [[VAL_78]], [[VAL_79]] : i32 // CHECK: return [[VAL_80]] : i32 // CHECK: ^bb4([[VAL_81:%.*]]: i32): // CHECK: return [[VAL_81]] : i32 // CHECK: ^bb5: -// CHECK: [[VAL_82:%.*]] = constant 0 : i32 +// CHECK: [[VAL_82:%.*]] = arith.constant 0 : i32 // CHECK: return [[VAL_82]] : i32 // CHECK: } fir.select %arg:i32 [ 1,^bb1(%0:i32), 2,^bb2(%2,%arg,%arg2:i32,i32,i32), -3,^bb3(%arg2,%2:i32,i32), 4,^bb4(%1:i32), unit,^bb5 ] ^bb1(%a : i32) : return %a : i32 ^bb2(%b : i32, %b2 : i32, %b3:i32) : - %4 = addi %b, %b2 : i32 - %5 = addi %4, %b3 : i32 + %4 = arith.addi %b, %b2 : i32 + %5 = arith.addi %4, %b3 : i32 return %5 : i32 ^bb3(%c:i32, %c2:i32) : - %6 = addi %c, %c2 : i32 + %6 = arith.addi %c, %c2 : i32 return %6 : i32 ^bb4(%d : i32) : return %d : i32 ^bb5 : - %zero = constant 0 : i32 + %zero = arith.constant 0 : i32 return %zero : i32 } // CHECK-LABEL: func @bar_select_rank( // CHECK-SAME: [[VAL_83:%.*]]: i32, [[VAL_84:%.*]]: i32) -> i32 { func @bar_select_rank(%arg : i32, %arg2 : i32) -> i32 { -// CHECK: [[VAL_85:%.*]] = constant 1 : i32 -// CHECK: [[VAL_86:%.*]] = constant 2 : i32 -// CHECK: [[VAL_87:%.*]] = constant 3 : i32 -// CHECK: [[VAL_88:%.*]] = constant 4 : i32 - %0 = constant 1 : i32 - %1 = constant 2 : i32 - %2 = constant 3 : i32 - %3 = constant 4 : i32 +// CHECK: [[VAL_85:%.*]] = arith.constant 1 : i32 +// CHECK: [[VAL_86:%.*]] = arith.constant 2 : i32 +// CHECK: [[VAL_87:%.*]] = arith.constant 3 : i32 +// CHECK: [[VAL_88:%.*]] = arith.constant 4 : i32 + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + %3 = arith.constant 4 : i32 // CHECK: fir.select_rank [[VAL_83]] : i32 [1, ^bb1([[VAL_85]] : i32), 2, ^bb2([[VAL_87]], [[VAL_83]], [[VAL_84]] : i32, i32, i32), 3, ^bb3([[VAL_84]], [[VAL_87]] : i32, i32), -1, ^bb4([[VAL_86]] : i32), unit, ^bb5] // CHECK: ^bb1([[VAL_89:%.*]]: i32): // CHECK: return [[VAL_89]] : i32 // CHECK: ^bb2([[VAL_90:%.*]]: i32, [[VAL_91:%.*]]: i32, [[VAL_92:%.*]]: i32): -// CHECK: [[VAL_93:%.*]] = addi [[VAL_90]], [[VAL_91]] : i32 -// CHECK: [[VAL_94:%.*]] = addi [[VAL_93]], [[VAL_92]] : i32 +// CHECK: [[VAL_93:%.*]] = arith.addi [[VAL_90]], [[VAL_91]] : i32 +// CHECK: [[VAL_94:%.*]] = arith.addi [[VAL_93]], [[VAL_92]] : i32 // CHECK: return [[VAL_94]] : i32 fir.select_rank %arg:i32 [ 1,^bb1(%0:i32), 2,^bb2(%2,%arg,%arg2:i32,i32,i32), 3,^bb3(%arg2,%2:i32,i32), -1,^bb4(%1:i32), unit,^bb5 ] ^bb1(%a : i32) : return %a : i32 ^bb2(%b : i32, %b2 : i32, %b3:i32) : - %4 = addi %b, %b2 : i32 - %5 = addi %4, %b3 : i32 + %4 = arith.addi %b, %b2 : i32 + %5 = arith.addi %4, %b3 : i32 return %5 : i32 // CHECK: ^bb3([[VAL_95:%.*]]: i32, [[VAL_96:%.*]]: i32): -// CHECK: [[VAL_97:%.*]] = addi [[VAL_95]], [[VAL_96]] : i32 +// CHECK: [[VAL_97:%.*]] = arith.addi [[VAL_95]], [[VAL_96]] : i32 // CHECK: return [[VAL_97]] : i32 // CHECK: ^bb4([[VAL_98:%.*]]: i32): // CHECK: return [[VAL_98]] : i32 ^bb3(%c:i32, %c2:i32) : - %6 = addi %c, %c2 : i32 + %6 = arith.addi %c, %c2 : i32 return %6 : i32 ^bb4(%d : i32) : return %d : i32 // CHECK: ^bb5: -// CHECK: [[VAL_99:%.*]] = constant 0 : i32 +// CHECK: [[VAL_99:%.*]] = arith.constant 0 : i32 // CHECK: [[VAL_100:%.*]] = fir.call @get_method_box() : () -> !fir.box> // 
CHECK: fir.dispatch "method"([[VAL_100]]) : (!fir.box>) -> () ^bb5 : - %zero = constant 0 : i32 + %zero = arith.constant 0 : i32 %7 = fir.call @get_method_box() : () -> !fir.box> fir.dispatch method(%7) : (!fir.box>) -> () @@ -318,14 +318,14 @@ // CHECK-SAME: [[VAL_101:%.*]]: !fir.box}>>) -> i32 { func @bar_select_type(%arg : !fir.box}>>) -> i32 { -// CHECK: [[VAL_102:%.*]] = constant 1 : i32 -// CHECK: [[VAL_103:%.*]] = constant 2 : i32 -// CHECK: [[VAL_104:%.*]] = constant 3 : i32 -// CHECK: [[VAL_105:%.*]] = constant 4 : i32 - %0 = constant 1 : i32 - %1 = constant 2 : i32 - %2 = constant 3 : i32 - %3 = constant 4 : i32 +// CHECK: [[VAL_102:%.*]] = arith.constant 1 : i32 +// CHECK: [[VAL_103:%.*]] = arith.constant 2 : i32 +// CHECK: [[VAL_104:%.*]] = arith.constant 3 : i32 +// CHECK: [[VAL_105:%.*]] = arith.constant 4 : i32 + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + %3 = arith.constant 4 : i32 // CHECK: fir.select_type [[VAL_101]] : !fir.box}>> [#fir.instance>, ^bb1([[VAL_102]] : i32), #fir.instance>, ^bb2([[VAL_104]] : i32), #fir.subsumed>, ^bb3([[VAL_104]] : i32), #fir.instance>, ^bb4([[VAL_103]] : i32), unit, ^bb5] fir.select_type %arg : !fir.box}>> [ #fir.instance>,^bb1(%0:i32), #fir.instance>,^bb2(%2:i32), #fir.subsumed>,^bb3(%2:i32), #fir.instance>,^bb4(%1:i32), unit,^bb5 ] @@ -348,25 +348,25 @@ return %d : i32 // CHECK: ^bb5: -// CHECK: [[VAL_110:%.*]] = constant 0 : i32 +// CHECK: [[VAL_110:%.*]] = arith.constant 0 : i32 // CHECK: return [[VAL_110]] : i32 // CHECK: } ^bb5 : - %zero = constant 0 : i32 + %zero = arith.constant 0 : i32 return %zero : i32 } // CHECK-LABEL: func @bar_select_case( // CHECK-SAME: [[VAL_111:%.*]]: i32, [[VAL_112:%.*]]: i32) -> i32 { -// CHECK: [[VAL_113:%.*]] = constant 1 : i32 -// CHECK: [[VAL_114:%.*]] = constant 2 : i32 -// CHECK: [[VAL_115:%.*]] = constant 3 : i32 -// CHECK: [[VAL_116:%.*]] = constant 4 : i32 +// CHECK: [[VAL_113:%.*]] = arith.constant 1 : i32 +// CHECK: [[VAL_114:%.*]] = arith.constant 2 : i32 +// CHECK: [[VAL_115:%.*]] = arith.constant 3 : i32 +// CHECK: [[VAL_116:%.*]] = arith.constant 4 : i32 func @bar_select_case(%arg : i32, %arg2 : i32) -> i32 { - %0 = constant 1 : i32 - %1 = constant 2 : i32 - %2 = constant 3 : i32 - %3 = constant 4 : i32 + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + %3 = arith.constant 4 : i32 // CHECK: fir.select_case [[VAL_111]] : i32 [#fir.point, [[VAL_113]], ^bb1([[VAL_113]] : i32), #fir.lower, [[VAL_114]], ^bb2([[VAL_115]], [[VAL_111]], [[VAL_112]], [[VAL_114]] : i32, i32, i32, i32), #fir.interval, [[VAL_115]], [[VAL_116]], ^bb3([[VAL_115]], [[VAL_112]] : i32, i32), #fir.upper, [[VAL_111]], ^bb4([[VAL_114]] : i32), unit, ^bb5] fir.select_case %arg : i32 [#fir.point, %0, ^bb1(%0:i32), #fir.lower, %1, ^bb2(%2,%arg,%arg2,%1:i32,i32,i32,i32), #fir.interval, %2, %3, ^bb3(%2,%arg2:i32,i32), #fir.upper, %arg, ^bb4(%1:i32), unit, ^bb5] @@ -374,52 +374,52 @@ // CHECK: ^bb1([[VAL_117:%.*]]: i32): // CHECK: return [[VAL_117]] : i32 // CHECK: ^bb2([[VAL_118:%.*]]: i32, [[VAL_119:%.*]]: i32, [[VAL_120:%.*]]: i32, [[VAL_121:%.*]]: i32): -// CHECK: [[VAL_122:%.*]] = addi [[VAL_118]], [[VAL_119]] : i32 -// CHECK: [[VAL_123:%.*]] = muli [[VAL_122]], [[VAL_120]] : i32 -// CHECK: [[VAL_124:%.*]] = addi [[VAL_123]], [[VAL_121]] : i32 +// CHECK: [[VAL_122:%.*]] = arith.addi [[VAL_118]], [[VAL_119]] : i32 +// CHECK: [[VAL_123:%.*]] = arith.muli [[VAL_122]], [[VAL_120]] : i32 +// CHECK: [[VAL_124:%.*]] = arith.addi [[VAL_123]], 
[[VAL_121]] : i32 // CHECK: return [[VAL_124]] : i32 // CHECK: ^bb3([[VAL_125:%.*]]: i32, [[VAL_126:%.*]]: i32): -// CHECK: [[VAL_127:%.*]] = addi [[VAL_125]], [[VAL_126]] : i32 +// CHECK: [[VAL_127:%.*]] = arith.addi [[VAL_125]], [[VAL_126]] : i32 // CHECK: return [[VAL_127]] : i32 // CHECK: ^bb4([[VAL_128:%.*]]: i32): // CHECK: return [[VAL_128]] : i32 ^bb1(%a : i32) : return %a : i32 ^bb2(%b : i32, %b2:i32, %b3:i32, %b4:i32) : - %4 = addi %b, %b2 : i32 - %5 = muli %4, %b3 : i32 - %6 = addi %5, %b4 : i32 + %4 = arith.addi %b, %b2 : i32 + %5 = arith.muli %4, %b3 : i32 + %6 = arith.addi %5, %b4 : i32 return %6 : i32 ^bb3(%c : i32, %c2 : i32) : - %7 = addi %c, %c2 : i32 + %7 = arith.addi %c, %c2 : i32 return %7 : i32 ^bb4(%d : i32) : return %d : i32 // CHECK: ^bb5: -// CHECK: [[VAL_129:%.*]] = constant 0 : i32 +// CHECK: [[VAL_129:%.*]] = arith.constant 0 : i32 // CHECK: return [[VAL_129]] : i32 // CHECK: } ^bb5 : - %zero = constant 0 : i32 + %zero = arith.constant 0 : i32 return %zero : i32 } // CHECK-LABEL: fir.global @global_var : i32 { -// CHECK: [[VAL_130:%.*]] = constant 1 : i32 +// CHECK: [[VAL_130:%.*]] = arith.constant 1 : i32 // CHECK: fir.has_value [[VAL_130]] : i32 // CHECK: } fir.global @global_var : i32 { - %0 = constant 1 : i32 + %0 = arith.constant 1 : i32 fir.has_value %0 : i32 } // CHECK-LABEL: fir.global @global_constant constant : i32 { -// CHECK: [[VAL_131:%.*]] = constant 934 : i32 +// CHECK: [[VAL_131:%.*]] = arith.constant 934 : i32 // CHECK: fir.has_value [[VAL_131]] : i32 // CHECK: } fir.global @global_constant constant : i32 { - %0 = constant 934 : i32 + %0 = arith.constant 934 : i32 fir.has_value %0 : i32 } @@ -489,20 +489,20 @@ // CHECK-SAME: [[VAL_169:%.*]]: f128, [[VAL_170:%.*]]: f128) -> f128 { func @arith_real(%a : f128, %b : f128) -> f128 { -// CHECK: [[VAL_171:%.*]] = constant 1.0 +// CHECK: [[VAL_171:%.*]] = arith.constant 1.0 // CHECK: [[VAL_172:%.*]] = fir.convert [[VAL_171]] : (f32) -> f128 -// CHECK: [[VAL_173:%.*]] = negf [[VAL_169]] : f128 -// CHECK: [[VAL_174:%.*]] = addf [[VAL_172]], [[VAL_173]] : f128 -// CHECK: [[VAL_175:%.*]] = subf [[VAL_174]], [[VAL_170]] : f128 -// CHECK: [[VAL_176:%.*]] = mulf [[VAL_173]], [[VAL_175]] : f128 -// CHECK: [[VAL_177:%.*]] = divf [[VAL_176]], [[VAL_169]] : f128 - %c1 = constant 1.0 : f32 +// CHECK: [[VAL_173:%.*]] = arith.negf [[VAL_169]] : f128 +// CHECK: [[VAL_174:%.*]] = arith.addf [[VAL_172]], [[VAL_173]] : f128 +// CHECK: [[VAL_175:%.*]] = arith.subf [[VAL_174]], [[VAL_170]] : f128 +// CHECK: [[VAL_176:%.*]] = arith.mulf [[VAL_173]], [[VAL_175]] : f128 +// CHECK: [[VAL_177:%.*]] = arith.divf [[VAL_176]], [[VAL_169]] : f128 + %c1 = arith.constant 1.0 : f32 %0 = fir.convert %c1 : (f32) -> f128 - %1 = negf %a : f128 - %2 = addf %0, %1 : f128 - %3 = subf %2, %b : f128 - %4 = mulf %1, %3 : f128 - %5 = divf %4, %a : f128 + %1 = arith.negf %a : f128 + %2 = arith.addf %0, %1 : f128 + %3 = arith.subf %2, %b : f128 + %4 = arith.mulf %1, %3 : f128 + %5 = arith.divf %4, %a : f128 // CHECK: return [[VAL_177]] : f128 // CHECK: } return %5 : f128 @@ -541,10 +541,10 @@ // CHECK-LABEL: func @early_exit( // CHECK-SAME: [[VAL_187:%.*]]: i1, [[VAL_188:%.*]]: i32) -> i1 { func @early_exit(%ok : i1, %k : i32) -> i1 { -// CHECK: [[VAL_189:%.*]] = constant 1 : index -// CHECK: [[VAL_190:%.*]] = constant 100 : index - %c1 = constant 1 : index - %c100 = constant 100 : index +// CHECK: [[VAL_189:%.*]] = arith.constant 1 : index +// CHECK: [[VAL_190:%.*]] = arith.constant 100 : index + %c1 = arith.constant 1 : index + %c100 = 
arith.constant 100 : index // CHECK: %[[VAL_191:.*]]:2 = fir.iterate_while ([[VAL_192:%.*]] = [[VAL_189]] to [[VAL_190]] step [[VAL_189]]) and ([[VAL_193:%.*]] = [[VAL_187]]) iter_args([[VAL_194:%.*]] = [[VAL_188]]) -> (i32) { // CHECK: [[VAL_195:%.*]] = call @earlyexit2([[VAL_194]]) : (i32) -> i1 @@ -561,29 +561,29 @@ // CHECK-LABEL: @array_access func @array_access(%arr : !fir.ref>) { - // CHECK-DAG: %[[c1:.*]] = constant 100 - // CHECK-DAG: %[[c2:.*]] = constant 50 - %c100 = constant 100 : index - %c50 = constant 50 : index + // CHECK-DAG: %[[c1:.*]] = arith.constant 100 + // CHECK-DAG: %[[c2:.*]] = arith.constant 50 + %c100 = arith.constant 100 : index + %c50 = arith.constant 50 : index // CHECK: %[[sh:.*]] = fir.shape %[[c1]], %[[c2]] : {{.*}} -> !fir.shape<2> %shape = fir.shape %c100, %c50 : (index, index) -> !fir.shape<2> - %c47 = constant 47 : index - %c78 = constant 78 : index - %c3 = constant 3 : index - %c18 = constant 18 : index - %c36 = constant 36 : index - %c4 = constant 4 : index + %c47 = arith.constant 47 : index + %c78 = arith.constant 78 : index + %c3 = arith.constant 3 : index + %c18 = arith.constant 18 : index + %c36 = arith.constant 36 : index + %c4 = arith.constant 4 : index // CHECK: %[[sl:.*]] = fir.slice {{.*}} -> !fir.slice<2> %slice = fir.slice %c47, %c78, %c3, %c18, %c36, %c4 : (index,index,index,index,index,index) -> !fir.slice<2> - %c0 = constant 0 : index - %c99 = constant 99 : index - %c1 = constant 1 : index + %c0 = arith.constant 0 : index + %c99 = arith.constant 99 : index + %c1 = arith.constant 1 : index fir.do_loop %i = %c0 to %c99 step %c1 { - %c49 = constant 49 : index + %c49 = arith.constant 49 : index fir.do_loop %j = %c0 to %c49 step %c1 { // CHECK: fir.array_coor %{{.*}}(%[[sh]]) [%[[sl]]] %{{.*}}, %{{.*}} : %p = fir.array_coor %arr(%shape)[%slice] %i, %j : (!fir.ref>, !fir.shape<2>, !fir.slice<2>, index, index) -> !fir.ref - %x = constant 42.0 : f32 + %x = arith.constant 42.0 : f32 fir.store %x to %p : !fir.ref } } @@ -607,16 +607,16 @@ // CHECK-LABEL: @test_misc_ops( // CHECK-SAME: [[ARR1:%.*]]: !fir.ref>, [[INDXM:%.*]]: index, [[INDXN:%.*]]: index, [[INDXO:%.*]]: index, [[INDXP:%.*]]: index) func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { - // CHECK: [[I10:%.*]] = constant 10 : index - // CHECK: [[J20:%.*]] = constant 20 : index - // CHECK: [[C2:%.*]] = constant 2 : index - // CHECK: [[C9:%.*]] = constant 9 : index - // CHECK: [[C1_I32:%.*]] = constant 9 : i32 - %i10 = constant 10 : index - %j20 = constant 20 : index - %c2 = constant 2 : index - %c9 = constant 9 : index - %c1_i32 = constant 9 : i32 + // CHECK: [[I10:%.*]] = arith.constant 10 : index + // CHECK: [[J20:%.*]] = arith.constant 20 : index + // CHECK: [[C2:%.*]] = arith.constant 2 : index + // CHECK: [[C9:%.*]] = arith.constant 9 : index + // CHECK: [[C1_I32:%.*]] = arith.constant 9 : i32 + %i10 = arith.constant 10 : index + %j20 = arith.constant 20 : index + %c2 = arith.constant 2 : index + %c9 = arith.constant 9 : index + %c1_i32 = arith.constant 9 : i32 // CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32> // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]], [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32> @@ -651,8 +651,8 @@ // CHECK-LABEL: @test_shift func @test_shift(%arg0: !fir.box>) -> !fir.ref { - %c4 = constant 4 : index - %c100 = constant 100 : index + %c4 = arith.constant 4 : index + %c100 = arith.constant 100 : index // CHECK: fir.shift %{{.*}} : (index) -> !fir.shift<1> %0 = fir.shift 
%c4 : (index) -> !fir.shift<1> %1 = fir.array_coor %arg0(%0) %c100 : (!fir.box>, !fir.shift<1>, index) -> !fir.ref @@ -662,13 +662,13 @@ func private @bar_rebox_test(!fir.box>) // CHECK-LABEL: @test_rebox( func @test_rebox(%arg0: !fir.box>) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c2 = constant 2 : index - %c3 = constant 3 : index - %c4 = constant 4 : index - %c10 = constant 10 : index - %c33 = constant 33 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c10 = arith.constant 10 : index + %c33 = arith.constant 33 : index %0 = fir.slice %c10, %c33, %c2 : (index, index, index) -> !fir.slice<1> %1 = fir.shift %c0 : (index) -> !fir.shift<1> // CHECK: fir.rebox %{{.*}}(%{{.*}}) [%{{.*}}] : (!fir.box>, !fir.shift<1>, !fir.slice<1>) -> !fir.box> @@ -682,8 +682,8 @@ // CHECK-LABEL: @test_save_result( func @test_save_result(%buffer: !fir.ref>>) { - %c100 = constant 100 : index - %c50 = constant 50 : index + %c100 = arith.constant 100 : index + %c50 = arith.constant 50 : index %shape = fir.shape %c100 : (index) -> !fir.shape<1> %res = fir.call @array_func() : () -> !fir.array> // CHECK: fir.save_result %{{.*}} to %{{.*}}(%{{.*}}) typeparams %{{.*}} : !fir.array>, !fir.ref>>, !fir.shape<1>, index diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -18,7 +18,7 @@ // ----- func @bad_rebox_1(%arg0: !fir.ref>) { - %c10 = constant 10 : index + %c10 = arith.constant 10 : index %0 = fir.shape %c10 : (index) -> !fir.shape<1> // expected-error@+1{{op operand #0 must be The type of a Fortran descriptor, but got '!fir.ref>'}} %1 = fir.rebox %arg0(%0) : (!fir.ref>, !fir.shape<1>) -> !fir.box> @@ -28,7 +28,7 @@ // ----- func @bad_rebox_2(%arg0: !fir.box>) { - %c10 = constant 10 : index + %c10 = arith.constant 10 : index %0 = fir.shape %c10 : (index) -> !fir.shape<1> // expected-error@+1{{op result #0 must be The type of a Fortran descriptor, but got '!fir.ref>'}} %1 = fir.rebox %arg0(%0) : (!fir.box>, !fir.shape<1>) -> !fir.ref> @@ -38,7 +38,7 @@ // ----- func @bad_rebox_3(%arg0: !fir.box>) { - %c10 = constant 10 : index + %c10 = arith.constant 10 : index %0 = fir.shape %c10 : (index) -> !fir.shape<1> // expected-error@+1{{op box operand must not have unknown rank or type}} %1 = fir.rebox %arg0(%0) : (!fir.box>, !fir.shape<1>) -> !fir.box> @@ -56,8 +56,8 @@ // ----- func @bad_rebox_5(%arg0: !fir.box>) { - %c1 = constant 1 : index - %c10 = constant 10 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index %0 = fir.slice %c1, %c10, %c1 : (index, index, index) -> !fir.slice<1> // expected-error@+1{{op slice operand rank must match box operand rank}} %1 = fir.rebox %arg0 [%0] : (!fir.box>, !fir.slice<1>) -> !fir.box> @@ -67,8 +67,8 @@ // ----- func @bad_rebox_6(%arg0: !fir.box>) { - %c1 = constant 1 : index - %c10 = constant 10 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index %0 = fir.slice %c1, %c10, %c1 : (index, index, index) -> !fir.slice<1> %1 = fir.shift %c1, %c1 : (index, index) -> !fir.shift<2> // expected-error@+1{{shape operand and input box ranks must match when there is a slice}} @@ -79,8 +79,8 @@ // ----- func @bad_rebox_7(%arg0: !fir.box>) { - %c1 = constant 1 : index - %c10 = constant 10 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index %0 = fir.slice %c1, %c10, %c1 : (index, index, index) -> 
!fir.slice<1> %1 = fir.shape %c10 : (index) -> !fir.shape<1> // expected-error@+1{{shape operand must absent or be a fir.shift when there is a slice}} @@ -91,8 +91,8 @@ // ----- func @bad_rebox_8(%arg0: !fir.box>) { - %c1 = constant 1 : index - %c10 = constant 10 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index %undef = fir.undefined index %0 = fir.slice %c1, %undef, %undef, %c1, %c10, %c1 : (index, index, index, index, index, index) -> !fir.slice<2> // expected-error@+1{{result type rank and rank after applying slice operand must match}} @@ -103,7 +103,7 @@ // ----- func @bad_rebox_9(%arg0: !fir.box>) { - %c10 = constant 10 : index + %c10 = arith.constant 10 : index %0 = fir.shift %c10, %c10 : (index, index) -> !fir.shift<2> // expected-error@+1{{shape operand and input box ranks must match when the shape is a fir.shift}} %1 = fir.rebox %arg0(%0) : (!fir.box>, !fir.shift<2>) -> !fir.box> @@ -113,7 +113,7 @@ // ----- func @bad_rebox_10(%arg0: !fir.box>) { - %c10 = constant 10 : index + %c10 = arith.constant 10 : index %0 = fir.shape %c10, %c10 : (index, index) -> !fir.shape<2> // expected-error@+1{{result type and shape operand ranks must match}} %1 = fir.rebox %arg0(%0) : (!fir.box>, !fir.shape<2>) -> !fir.box> @@ -123,7 +123,7 @@ // ----- func @bad_rebox_11(%arg0: !fir.box>) { - %c42 = constant 42 : index + %c42 = arith.constant 42 : index %0 = fir.shape %c42 : (index) -> !fir.shape<1> // expected-error@+1{{op input and output element types must match for intrinsic types}} %1 = fir.rebox %arg0(%0) : (!fir.box>, !fir.shape<1>) -> !fir.box> @@ -133,9 +133,9 @@ // ----- func @array_access(%arr : !fir.ref>) { - %c1 = constant 1 : index - %c100 = constant 100 : index - %c50 = constant 50 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %c50 = arith.constant 50 : index %shape = fir.shape %c100, %c50 : (index, index) -> !fir.shape<2> // expected-error@+1 {{'fir.array_coor' op operand #0 must be any reference or box, but got 'index'}} %p = fir.array_coor %c100(%shape) %c1, %c1 : (index, !fir.shape<2>, index, index) -> !fir.ref @@ -145,9 +145,9 @@ // ----- func @array_access(%arr : !fir.ref) { - %c1 = constant 1 : index - %c100 = constant 100 : index - %c50 = constant 50 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %c50 = arith.constant 50 : index %shape = fir.shape %c100, %c50 : (index, index) -> !fir.shape<2> // expected-error@+1 {{'fir.array_coor' op must be a reference to an array}} %p = fir.array_coor %arr(%shape) %c1, %c1 : (!fir.ref, !fir.shape<2>, index, index) -> !fir.ref @@ -157,13 +157,13 @@ // ----- func @array_access(%arr : !fir.ref>) { - %c1 = constant 1 : index - %c100 = constant 100 : index - %c50 = constant 50 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %c50 = arith.constant 50 : index %shape = fir.shape %c100, %c50 : (index, index) -> !fir.shape<2> - %c47 = constant 47 : index - %c78 = constant 78 : index - %c3 = constant 3 : index + %c47 = arith.constant 47 : index + %c78 = arith.constant 78 : index + %c3 = arith.constant 3 : index %slice = fir.slice %c47, %c78, %c3 : (index,index,index) -> !fir.slice<1> // expected-error@+1 {{'fir.array_coor' op rank of dimension in slice mismatched}} %p = fir.array_coor %arr(%shape)[%slice] %c1, %c1 : (!fir.ref>, !fir.shape<2>, !fir.slice<1>, index, index) -> !fir.ref @@ -173,8 +173,8 @@ // ----- func @array_access(%arr : !fir.ref>) { - %c1 = constant 1 : index - %c100 = constant 100 : index + %c1 = arith.constant 1 
: index + %c100 = arith.constant 100 : index %shape = fir.shape %c100 : (index) -> !fir.shape<1> // expected-error@+1 {{'fir.array_coor' op rank of dimension mismatched}} %p = fir.array_coor %arr(%shape) %c1, %c1 : (!fir.ref>, !fir.shape<1>, index, index) -> !fir.ref @@ -184,8 +184,8 @@ // ----- func @array_access(%arr : !fir.ref>) { - %c1 = constant 1 : index - %c100 = constant 100 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index %shift = fir.shift %c1 : (index) -> !fir.shift<1> // expected-error@+1 {{'fir.array_coor' op shift can only be provided with fir.box memref}} %p = fir.array_coor %arr(%shift) %c1, %c1 : (!fir.ref>, !fir.shift<1>, index, index) -> !fir.ref @@ -195,9 +195,9 @@ // ----- func @array_access(%arr : !fir.ref>) { - %c1 = constant 1 : index - %c100 = constant 100 : index - %c50 = constant 50 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %c50 = arith.constant 50 : index %shape = fir.shape %c100, %c50 : (index, index) -> !fir.shape<2> // expected-error@+1 {{'fir.array_coor' op number of indices do not match dim rank}} %p = fir.array_coor %arr(%shape) %c1 : (!fir.ref>, !fir.shape<2>, index) -> !fir.ref @@ -207,7 +207,7 @@ // ----- func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { - %c2 = constant 2 : index + %c2 = arith.constant 2 : index %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> // expected-error@+1 {{'fir.array_load' op operand #0 must be any reference or box, but got 'index'}} %av1 = fir.array_load %c2(%s) : (index, !fir.shapeshift<2>) -> !fir.array @@ -235,7 +235,7 @@ // ----- func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { - %c2 = constant 2 : index + %c2 = arith.constant 2 : index %shift = fir.shift %c2 : (index) -> !fir.shift<1> // expected-error@+1 {{'fir.array_load' op shift can only be provided with fir.box memref}} %av1 = fir.array_load %arr1(%shift) : (!fir.ref>, !fir.shift<1>) -> !fir.array @@ -245,9 +245,9 @@ // ----- func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { - %c47 = constant 47 : index - %c78 = constant 78 : index - %c3 = constant 3 : index + %c47 = arith.constant 47 : index + %c78 = arith.constant 78 : index + %c3 = arith.constant 3 : index %slice = fir.slice %c47, %c78, %c3 : (index,index,index) -> !fir.slice<1> %s = fir.shape_shift %m, %n, %o, %p: (index, index, index, index) -> !fir.shapeshift<2> // expected-error@+1 {{'fir.array_load' op rank of dimension in slice mismatched}} @@ -258,7 +258,7 @@ // ----- func @test_coordinate_of(%arr : !fir.ref>) { - %1 = constant 10 : i32 + %1 = arith.constant 10 : i32 // expected-error@+1 {{'fir.coordinate_of' op cannot find coordinate with unknown extents}} %2 = fir.coordinate_of %arr, %1 : (!fir.ref>, i32) -> !fir.ref return @@ -267,7 +267,7 @@ // ----- func @test_coordinate_of(%arr : !fir.ref>) { - %1 = constant 10 : i32 + %1 = arith.constant 10 : i32 // expected-error@+1 {{'fir.coordinate_of' op cannot find coordinate in unknown shape}} %2 = fir.coordinate_of %arr, %1 : (!fir.ref>, i32) -> !fir.ref return @@ -276,7 +276,7 @@ // ----- func @test_coordinate_of(%arr : !fir.ref>) { - %1 = constant 10 : i32 + %1 = arith.constant 10 : i32 // expected-error@+1 {{'fir.coordinate_of' op cannot apply coordinate_of to this type}} %2 = fir.coordinate_of %arr, %1 : (!fir.ref>, i32) -> !fir.ref return @@ -284,14 +284,14 @@ // ----- -%0 = constant 22 : i32 +%0 = arith.constant 22 : i32 // 
expected-error@+1 {{'fir.embox' op operand #0 must be any reference, but got 'i32'}} %1 = fir.embox %0 : (i32) -> !fir.box // ----- func @fun(%0 : !fir.ref) { - %c_100 = constant 100 : index + %c_100 = arith.constant 100 : index %1 = fir.shape %c_100 : (index) -> !fir.shape<1> // expected-error@+1 {{'fir.embox' op shape must not be provided for a scalar}} %2 = fir.embox %0(%1) : (!fir.ref, !fir.shape<1>) -> !fir.box @@ -300,7 +300,7 @@ // ----- func @fun(%0 : !fir.ref) { - %c_100 = constant 100 : index + %c_100 = arith.constant 100 : index %1 = fir.slice %c_100, %c_100, %c_100 : (index, index, index) -> !fir.slice<1> // expected-error@+1 {{'fir.embox' op operand #1 must be any legal shape type, but got '!fir.slice<1>'}} %2 = fir.embox %0(%1) : (!fir.ref, !fir.slice<1>) -> !fir.box @@ -309,7 +309,7 @@ // ----- func @fun(%0 : !fir.ref) { - %c_100 = constant 100 : index + %c_100 = arith.constant 100 : index %1 = fir.shape %c_100 : (index) -> !fir.shape<1> // expected-error@+1 {{'fir.embox' op operand #1 must be FIR slice, but got '!fir.shape<1>'}} %2 = fir.embox %0[%1] : (!fir.ref, !fir.shape<1>) -> !fir.box @@ -318,7 +318,7 @@ // ----- func @fun(%0 : !fir.ref) { - %c_100 = constant 100 : index + %c_100 = arith.constant 100 : index %1 = fir.slice %c_100, %c_100, %c_100 : (index, index, index) -> !fir.slice<1> // expected-error@+1 {{'fir.embox' op slice must not be provided for a scalar}} %2 = fir.embox %0[%1] : (!fir.ref, !fir.slice<1>) -> !fir.box @@ -326,11 +326,11 @@ // ----- -%lo = constant 1 : index -%c1 = constant 1 : index -%up = constant 10 : index -%okIn = constant 1 : i1 -%shIn = constant 1 : i16 +%lo = arith.constant 1 : index +%c1 = arith.constant 1 : index +%up = arith.constant 10 : index +%okIn = arith.constant 1 : i1 +%shIn = arith.constant 1 : i16 // expected-error@+1 {{'fir.iterate_while' op expected body first argument to be an index argument for the induction variable}} %v:3 = fir.iterate_while (%i = %lo to %up step %c1) and (%ok = %okIn) iter_args(%sh = %shIn) -> (i16, i1, i16) { %shNew = fir.call @bar(%sh) : (i16) -> i16 @@ -340,11 +340,11 @@ // ----- -%lo = constant 1 : index -%c1 = constant 1 : index -%up = constant 10 : index -%okIn = constant 1 : i1 -%shIn = constant 1 : i16 +%lo = arith.constant 1 : index +%c1 = arith.constant 1 : index +%up = arith.constant 10 : index +%okIn = arith.constant 1 : i1 +%shIn = arith.constant 1 : i16 // expected-error@+1 {{'fir.iterate_while' op expected body second argument to be an index argument for the induction variable}} %v:3 = fir.iterate_while (%i = %lo to %up step %c1) and (%ok = %okIn) iter_args(%sh = %shIn) -> (index, f32, i16) { %shNew = fir.call @bar(%sh) : (i16) -> i16 @@ -354,26 +354,26 @@ // ----- -%c1 = constant 1 : index -%c10 = constant 10 : index +%c1 = arith.constant 1 : index +%c10 = arith.constant 10 : index // expected-error@+1 {{'fir.do_loop' op unordered loop has no final value}} fir.do_loop %i = %c1 to %c10 step %c1 unordered -> index { } // ----- -%c1 = constant 1 : index -%c10 = constant 10 : index +%c1 = arith.constant 1 : index +%c10 = arith.constant 10 : index fir.do_loop %i = %c1 to %c10 step %c1 -> index { - %f1 = constant 1.0 : f32 + %f1 = arith.constant 1.0 : f32 // expected-error@+1 {{'fir.result' op types mismatch between result op and its parent}} fir.result %f1 : f32 } // ----- -%c1 = constant 1 : index -%c10 = constant 10 : index +%c1 = arith.constant 1 : index +%c10 = arith.constant 10 : index // expected-error@+1 {{'fir.result' op parent of result must have same arity}} fir.do_loop %i = %c1 to 
%c10 step %c1 -> index { } @@ -425,7 +425,7 @@ // ----- fir.global internal @_QEmultiarray : !fir.array<32x32xi32> { - %c0_i32 = constant 1 : i32 + %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> // expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}} %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> @@ -435,7 +435,7 @@ // ----- fir.global internal @_QEmultiarray : !fir.array<32x32xi32> { - %c0_i32 = constant 1 : i32 + %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> // expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}} %2 = fir.insert_on_range %0, %c0_i32, [0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> @@ -445,7 +445,7 @@ // ----- fir.global internal @_QEmultiarray : !fir.array<32x32xi32> { - %c0_i32 = constant 1 : i32 + %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> // expected-error@+1 {{'fir.insert_on_range' op negative range bound}} %2 = fir.insert_on_range %0, %c0_i32, [-1 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> @@ -455,7 +455,7 @@ // ----- fir.global internal @_QEmultiarray : !fir.array<32x32xi32> { - %c0_i32 = constant 1 : i32 + %c0_i32 = arith.constant 1 : i32 %0 = fir.undefined !fir.array<32x32xi32> // expected-error@+1 {{'fir.insert_on_range' op empty range}} %2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32> @@ -575,7 +575,7 @@ func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> %av1 = fir.array_load %arr1(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array - %c0 = constant 0 : i32 + %c0 = arith.constant 0 : i32 // expected-error@+1 {{'fir.array_update' op merged value does not have element type}} %av2 = fir.array_update %av1, %c0, %m, %n : (!fir.array, i32, index, index) -> !fir.array return @@ -596,8 +596,8 @@ // ----- func @bad_array_modify(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index, %f : f32) { - %i10 = constant 10 : index - %j20 = constant 20 : index + %i10 = arith.constant 10 : index + %j20 = arith.constant 20 : index %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> %av1 = fir.array_load %arr1(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array // expected-error@+1 {{'fir.array_modify' op number of indices must match array dimension}} diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -8,22 +8,22 @@ ### Pre-requisites -* A relatively recent Python3 installation -* Installation of python dependencies as specified in - `mlir/python/requirements.txt` +* A relatively recent Python3 installation +* Installation of python dependencies as specified in + `mlir/python/requirements.txt` ### CMake variables -* **`MLIR_ENABLE_BINDINGS_PYTHON`**`:BOOL` +* **`MLIR_ENABLE_BINDINGS_PYTHON`**`:BOOL` - Enables building the Python bindings. Defaults to `OFF`. + Enables building the Python bindings. Defaults to `OFF`. -* **`Python3_EXECUTABLE`**:`STRING` +* **`Python3_EXECUTABLE`**:`STRING` - Specifies the `python` executable used for the LLVM build, including for - determining header/link flags for the Python bindings. 
On systems with - multiple Python implementations, setting this explicitly to the preferred - `python3` executable is strongly recommended. + Specifies the `python` executable used for the LLVM build, including for + determining header/link flags for the Python bindings. On systems with + multiple Python implementations, setting this explicitly to the preferred + `python3` executable is strongly recommended. ### Recommended development practices @@ -62,8 +62,8 @@ export PYTHONPATH=$(cd build && pwd)/tools/mlir/python_packages/mlir_core ``` -Note that if you have installed (i.e. via `ninja install`, et al), then -python packages for all enabled projects will be in your install tree under +Note that if you have installed (i.e. via `ninja install`, et al), then python +packages for all enabled projects will be in your install tree under `python_packages/` (i.e. `python_packages/mlir_core`). Official distributions are built with a more specialized setup. @@ -73,12 +73,12 @@ There are likely two primary use cases for the MLIR python bindings: -1. Support users who expect that an installed version of LLVM/MLIR will yield - the ability to `import mlir` and use the API in a pure way out of the box. +1. Support users who expect that an installed version of LLVM/MLIR will yield + the ability to `import mlir` and use the API in a pure way out of the box. -1. Downstream integrations will likely want to include parts of the API in their - private namespace or specially built libraries, probably mixing it with other - python native bits. +1. Downstream integrations will likely want to include parts of the API in + their private namespace or specially built libraries, probably mixing it + with other python native bits. ### Composable modules @@ -86,15 +86,15 @@ composable modules that downstream integrators can include and re-export into their own namespace if desired. This forces several design points: -* Separate the construction/populating of a `py::module` from `PYBIND11_MODULE` - global constructor. +* Separate the construction/populating of a `py::module` from + `PYBIND11_MODULE` global constructor. -* Introduce headers for C++-only wrapper classes as other related C++ modules - will need to interop with it. +* Introduce headers for C++-only wrapper classes as other related C++ modules + will need to interop with it. -* Separate any initialization routines that depend on optional components into - its own module/dependency (currently, things like `registerAllDialects` fall - into this category). +* Separate any initialization routines that depend on optional components into + its own module/dependency (currently, things like `registerAllDialects` fall + into this category). There are a lot of co-related issues of shared library linkage, distribution concerns, etc that affect such things. 
Organizing the code into composable @@ -113,17 +113,17 @@ Examples: -* `mlir.ir` -* `mlir.passes` (`pass` is a reserved word :( ) -* `mlir.dialect` -* `mlir.execution_engine` (aside from namespacing, it is important that - "bulky"/optional parts like this are isolated) +* `mlir.ir` +* `mlir.passes` (`pass` is a reserved word :( ) +* `mlir.dialect` +* `mlir.execution_engine` (aside from namespacing, it is important that + "bulky"/optional parts like this are isolated) -In addition, initialization functions that imply optional dependencies should -be in underscored (notionally private) modules such as `_init` and linked +In addition, initialization functions that imply optional dependencies should be +in underscored (notionally private) modules such as `_init` and linked separately. This allows downstream integrators to completely customize what is -included "in the box" and covers things like dialect registration, -pass registration, etc. +included "in the box" and covers things like dialect registration, pass +registration, etc. ### Loader @@ -131,17 +131,16 @@ other non-trivial native extensions. As such, the native extension (i.e. the `.so`/`.pyd`/`.dylib`) is exported as a notionally private top-level symbol (`_mlir`), while a small set of Python code is provided in -`mlir/_cext_loader.py` and siblings which loads and re-exports it. This -split provides a place to stage code that needs to prepare the environment -*before* the shared library is loaded into the Python runtime, and also -provides a place that one-time initialization code can be invoked apart from -module constructors. +`mlir/_cext_loader.py` and siblings which loads and re-exports it. This split +provides a place to stage code that needs to prepare the environment *before* +the shared library is loaded into the Python runtime, and also provides a place +that one-time initialization code can be invoked apart from module constructors. It is recommended to avoid using `__init__.py` files to the extent possible, -until reaching a leaf package that represents a discrete component. The rule -to keep in mind is that the presence of an `__init__.py` file prevents the -ability to split anything at that level or below in the namespace into -different directories, deployment packages, wheels, etc. +until reaching a leaf package that represents a discrete component. The rule to +keep in mind is that the presence of an `__init__.py` file prevents the ability +to split anything at that level or below in the namespace into different +directories, deployment packages, wheels, etc. See the documentation for more information and advice: https://packaging.python.org/guides/packaging-namespace-packages/ @@ -157,11 +156,12 @@ ### Ownership in the Core IR -There are several top-level types in the core IR that are strongly owned by their python-side reference: +There are several top-level types in the core IR that are strongly owned by +their python-side reference: -* `PyContext` (`mlir.ir.Context`) -* `PyModule` (`mlir.ir.Module`) -* `PyOperation` (`mlir.ir.Operation`) - but with caveats +* `PyContext` (`mlir.ir.Context`) +* `PyModule` (`mlir.ir.Module`) +* `PyOperation` (`mlir.ir.Operation`) - but with caveats All other objects are dependent. All objects maintain a back-reference (keep-alive) to their closest containing top-level object. 
Further, dependent @@ -173,11 +173,12 @@ ### Optionality and argument ordering in the Core IR -The following types support being bound to the current thread as a context manager: +The following types support being bound to the current thread as a context +manager: -* `PyLocation` (`loc: mlir.ir.Location = None`) -* `PyInsertionPoint` (`ip: mlir.ir.InsertionPoint = None`) -* `PyMlirContext` (`context: mlir.ir.Context = None`) +* `PyLocation` (`loc: mlir.ir.Location = None`) +* `PyInsertionPoint` (`ip: mlir.ir.InsertionPoint = None`) +* `PyMlirContext` (`context: mlir.ir.Context = None`) In order to support composability of function arguments, when these types appear as arguments, they should always be the last and appear in the above order and @@ -692,9 +693,9 @@ m.def("getContext", ...) ``` -### __repr__ methods +### **repr** methods -Things that have nice printed representations are really great :) If there is a +Things that have nice printed representations are really great :) If there is a reasonable printed form, it can be a significant productivity boost to wire that to the `__repr__` method (and verify it with a [doctest](#sample-doctest)). @@ -759,14 +760,14 @@ We use `lit` and `FileCheck` based tests: -* For generative tests (those that produce IR), define a Python module that - constructs/prints the IR and pipe it through `FileCheck`. -* Parsing should be kept self-contained within the module under test by use of - raw constants and an appropriate `parse_asm` call. -* Any file I/O code should be staged through a tempfile vs relying on file - artifacts/paths outside of the test module. -* For convenience, we also test non-generative API interactions with the same - mechanisms, printing and `CHECK`ing as needed. +* For generative tests (those that produce IR), define a Python module that + constructs/prints the IR and pipe it through `FileCheck`. +* Parsing should be kept self-contained within the module under test by use of + raw constants and an appropriate `parse_asm` call. +* Any file I/O code should be staged through a tempfile vs relying on file + artifacts/paths outside of the test module. +* For convenience, we also test non-generative API interactions with the same + mechanisms, printing and `CHECK`ing as needed. ### Sample FileCheck test @@ -794,13 +795,13 @@ ## Integration with ODS The MLIR Python bindings integrate with the tablegen-based ODS system for -providing user-friendly wrappers around MLIR dialects and operations. There -are multiple parts to this integration, outlined below. Most details have -been elided: refer to the build rules and python sources under `mlir.dialects` -for the canonical way to use this facility. +providing user-friendly wrappers around MLIR dialects and operations. There are +multiple parts to this integration, outlined below. Most details have been +elided: refer to the build rules and python sources under `mlir.dialects` for +the canonical way to use this facility. -Users are responsible for providing a `{DIALECT_NAMESPACE}.py` (or an -equivalent directory with `__init__.py` file) as the entrypoint. +Users are responsible for providing a `{DIALECT_NAMESPACE}.py` (or an equivalent +directory with `__init__.py` file) as the entrypoint. ### Generating `_{DIALECT_NAMESPACE}_ops_gen.py` wrapper modules @@ -838,10 +839,10 @@ ### Extending the search path for wrapper modules When the python bindings need to locate a wrapper module, they consult the -`dialect_search_path` and use it to find an appropriately named module. 
For -the main repository, this search path is hard-coded to include the -`mlir.dialects` module, which is where wrappers are emitted by the abobe build -rule. Out of tree dialects and add their modules to the search path by calling: +`dialect_search_path` and use it to find an appropriately named module. For the +main repository, this search path is hard-coded to include the `mlir.dialects` +module, which is where wrappers are emitted by the above build rule. Out of tree +dialects can add their modules to the search path by calling: ```python mlir._cext.append_dialect_search_prefix("myproject.mlir.dialects") ``` @@ -851,10 +852,10 @@ The wrapper module tablegen emitter outputs: -* A `_Dialect` class (extending `mlir.ir.Dialect`) with a `DIALECT_NAMESPACE` - attribute. -* An `{OpName}` class for each operation (extending `mlir.ir.OpView`). -* Decorators for each of the above to register with the system. +* A `_Dialect` class (extending `mlir.ir.Dialect`) with a `DIALECT_NAMESPACE` + attribute. +* An `{OpName}` class for each operation (extending `mlir.ir.OpView`). +* Decorators for each of the above to register with the system. Note: In order to avoid naming conflicts, all internal names used by the wrapper module are prefixed by `_ods_`. @@ -862,54 +863,54 @@ Each concrete `OpView` subclass further defines several public-intended attributes: -* `OPERATION_NAME` attribute with the `str` fully qualified operation name - (i.e. `std.absf`). -* An `__init__` method for the *default builder* if one is defined or inferred - for the operation. -* `@property` getter for each operand or result (using an auto-generated name - for unnamed of each). -* `@property` getter, setter and deleter for each declared attribute. +* `OPERATION_NAME` attribute with the `str` fully qualified operation name + (i.e. `math.abs`). +* An `__init__` method for the *default builder* if one is defined or inferred + for the operation. +* `@property` getter for each operand or result (using an auto-generated name + for unnamed of each). +* `@property` getter, setter and deleter for each declared attribute. It further emits additional private-intended attributes meant for subclassing -and customization (default cases omit these attributes in favor of the -defaults on `OpView`): - -* `_ODS_REGIONS`: A specification on the number and types of regions. - Currently a tuple of (min_region_count, has_no_variadic_regions). Note that - the API does some light validation on this but the primary purpose is to - capture sufficient information to perform other default building and region - accessor generation. -* `_ODS_OPERAND_SEGMENTS` and `_ODS_RESULT_SEGMENTS`: Black-box value which - indicates the structure of either the operand or results with respect to - variadics. Used by `OpView._ods_build_default` to decode operand and result - lists that contain lists. +and customization (default cases omit these attributes in favor of the defaults +on `OpView`): + +* `_ODS_REGIONS`: A specification on the number and types of regions. + Currently a tuple of (min_region_count, has_no_variadic_regions). Note that + the API does some light validation on this but the primary purpose is to + capture sufficient information to perform other default building and region + accessor generation. +* `_ODS_OPERAND_SEGMENTS` and `_ODS_RESULT_SEGMENTS`: Black-box value which + indicates the structure of either the operand or results with respect to + variadics. Used by `OpView._ods_build_default` to decode operand and result + lists that contain lists. 
#### Default Builder Presently, only a single, default builder is mapped to the `__init__` method. -The intent is that this `__init__` method represents the *most specific* of -the builders typically generated for C++; however currently it is just the -generic form below. - -* One argument for each declared result: - * For single-valued results: Each will accept an `mlir.ir.Type`. - * For variadic results: Each will accept a `List[mlir.ir.Type]`. -* One argument for each declared operand or attribute: - * For single-valued operands: Each will accept an `mlir.ir.Value`. - * For variadic operands: Each will accept a `List[mlir.ir.Value]`. - * For attributes, it will accept an `mlir.ir.Attribute`. -* Trailing usage-specific, optional keyword arguments: - * `loc`: An explicit `mlir.ir.Location` to use. Defaults to the location - bound to the thread (i.e. `with Location.unknown():`) or an error if none - is bound nor specified. - * `ip`: An explicit `mlir.ir.InsertionPoint` to use. Default to the insertion - point bound to the thread (i.e. `with InsertionPoint(...):`). +The intent is that this `__init__` method represents the *most specific* of the +builders typically generated for C++; however currently it is just the generic +form below. + +* One argument for each declared result: + * For single-valued results: Each will accept an `mlir.ir.Type`. + * For variadic results: Each will accept a `List[mlir.ir.Type]`. +* One argument for each declared operand or attribute: + * For single-valued operands: Each will accept an `mlir.ir.Value`. + * For variadic operands: Each will accept a `List[mlir.ir.Value]`. + * For attributes, it will accept an `mlir.ir.Attribute`. +* Trailing usage-specific, optional keyword arguments: + * `loc`: An explicit `mlir.ir.Location` to use. Defaults to the location + bound to the thread (i.e. `with Location.unknown():`) or an error if + none is bound nor specified. + * `ip`: An explicit `mlir.ir.InsertionPoint` to use. Defaults to the + insertion point bound to the thread (i.e. `with InsertionPoint(...):`). In addition, each `OpView` inherits a `build_generic` method which allows construction via a (nested in the case of variadic) sequence of `results` and `operands`. This can be used to get some default construction semantics for -operations that are otherwise unsupported in Python, at the expense of having -a very generic signature. +operations that are otherwise unsupported in Python, at the expense of having a +very generic signature. #### Extending Generated Op Classes @@ -919,15 +920,15 @@ provides some relatively simple examples. As mentioned above, the build system generates Python sources like -`_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It -is often desirable to to use these generated classes as a starting point for -further customization, so an extension mechanism is provided to make this -easy (you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` -file but we prefer a more standard mechanism that is applied uniformly). +`_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is +often desirable to use these generated classes as a starting point for +further customization, so an extension mechanism is provided to make this easy +(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file +but we prefer a more standard mechanism that is applied uniformly). To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the -`dialects` module (i.e. 
adjacent to your `{DIALECT_NAMESPACE}.py` top-level -and the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an +`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and +the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an example, the generated code will include an import like this: ```python @@ -949,41 +950,41 @@ See the `_ods_common.py` `extend_opview_class` function for details of the mechanism. At a high level: -* If the extension module exists, locate an extension class for the op (in - this example, `FuncOp`): - * First by looking for an attribute with the exact name in the extension - module. - * Falling back to calling a `select_opview_mixin(parent_opview_cls)` - function defined in the extension module. -* If a mixin class is found, a new subclass is dynamically created that multiply - inherits from `({_builtin_ops_ext.FuncOp}, _builtin_ops_gen.FuncOp)`. - -The mixin class should not inherit from anything (i.e. directly extends -`object` only). The facility is typically used to define custom `__init__` -methods, properties, instance methods and static methods. Due to the -inheritance ordering, the mixin class can act as though it extends the -generated `OpView` subclass in most contexts (i.e. -`issubclass(_builtin_ops_ext.FuncOp, OpView)` will return `False` but usage -generally allows you treat it as duck typed as an `OpView`). - -There are a couple of recommendations, given how the class hierarchy is -defined: - -* For static methods that need to instantiate the actual "leaf" op (which - is dynamically generated and would result in circular dependencies to try - to reference by name), prefer to use `@classmethod` and the concrete - subclass will be provided as your first `cls` argument. See - `_builtin_ops_ext.FuncOp.from_py_func` as an example. -* If seeking to replace the generated `__init__` method entirely, you may - actually want to invoke the super-super-class `mlir.ir.OpView` constructor - directly, as it takes an `mlir.ir.Operation`, which is likely what you - are constructing (i.e. the generated `__init__` method likely adds more - API constraints than you want to expose in a custom builder). +* If the extension module exists, locate an extension class for the op (in + this example, `FuncOp`): + * First by looking for an attribute with the exact name in the extension + module. + * Falling back to calling a `select_opview_mixin(parent_opview_cls)` + function defined in the extension module. +* If a mixin class is found, a new subclass is dynamically created that + multiply inherits from `({_builtin_ops_ext.FuncOp}, + _builtin_ops_gen.FuncOp)`. + +The mixin class should not inherit from anything (i.e. directly extends `object` +only). The facility is typically used to define custom `__init__` methods, +properties, instance methods and static methods. Due to the inheritance +ordering, the mixin class can act as though it extends the generated `OpView` +subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)` +will return `False` but usage generally allows you treat it as duck typed as an +`OpView`). + +There are a couple of recommendations, given how the class hierarchy is defined: + +* For static methods that need to instantiate the actual "leaf" op (which is + dynamically generated and would result in circular dependencies to try to + reference by name), prefer to use `@classmethod` and the concrete subclass + will be provided as your first `cls` argument. 
See + `_builtin_ops_ext.FuncOp.from_py_func` as an example. +* If seeking to replace the generated `__init__` method entirely, you may + actually want to invoke the super-super-class `mlir.ir.OpView` constructor + directly, as it takes an `mlir.ir.Operation`, which is likely what you are + constructing (i.e. the generated `__init__` method likely adds more API + constraints than you want to expose in a custom builder). A pattern that comes up frequently is wanting to provide a sugared `__init__` method which has optional or type-polymorphism/implicit conversions but to -otherwise want to invoke the default op building logic. For such cases, -it is recommended to use an idiom such as: +otherwise want to invoke the default op building logic. For such cases, it is +recommended to use an idiom such as: ```python def __init__(self, sugar, spice, *, loc=None, ip=None): diff --git a/mlir/docs/BufferDeallocationInternals.md b/mlir/docs/BufferDeallocationInternals.md --- a/mlir/docs/BufferDeallocationInternals.md +++ b/mlir/docs/BufferDeallocationInternals.md @@ -7,34 +7,34 @@ ## Requirements -In order to use BufferDeallocation on an arbitrary dialect, several -control-flow interfaces have to be implemented when using custom operations. -This is particularly important to understand the implicit control-flow -dependencies between different parts of the input program. Without implementing -the following interfaces, control-flow relations cannot be discovered properly -and the resulting program can become invalid: - -* Branch-like terminators should implement the `BranchOpInterface` to query and -manipulate associated operands. -* Operations involving structured control flow have to implement the -`RegionBranchOpInterface` to model inter-region control flow. -* Terminators yielding values to their parent operation (in particular in the -scope of nested regions within `RegionBranchOpInterface`-based operations), -should implement the `ReturnLike` trait to represent logical “value returns”. - -Example dialects that are fully compatible are the “std” and “scf” dialects -with respect to all implemented interfaces. +In order to use BufferDeallocation on an arbitrary dialect, several control-flow +interfaces have to be implemented when using custom operations. This is +particularly important to understand the implicit control-flow dependencies +between different parts of the input program. Without implementing the following +interfaces, control-flow relations cannot be discovered properly and the +resulting program can become invalid: + +* Branch-like terminators should implement the `BranchOpInterface` to query + and manipulate associated operands. +* Operations involving structured control flow have to implement the + `RegionBranchOpInterface` to model inter-region control flow. +* Terminators yielding values to their parent operation (in particular in the + scope of nested regions within `RegionBranchOpInterface`-based operations), + should implement the `ReturnLike` trait to represent logical “value + returns”. + +Example dialects that are fully compatible are the “std” and “scf” dialects with +respect to all implemented interfaces. During Bufferization, we convert immutable value types (tensors) to mutable types (memref). This conversion is done in several steps and in all of these -steps the IR has to fulfill SSA like properties. The usage of memref has -to be in the following consecutive order: allocation, write-buffer, read- -buffer. 
-In this case, there are only buffer reads allowed after the initial full -buffer write is done. In particular, there must be no partial write to a -buffer after the initial write has been finished. However, partial writes in -the initializing is allowed (fill buffer step by step in a loop e.g.). This -means, all buffer writes needs to dominate all buffer reads. +steps the IR has to fulfill SSA like properties. The usage of memref has to be +in the following consecutive order: allocation, write-buffer, read- buffer. In +this case, there are only buffer reads allowed after the initial full buffer +write is done. In particular, there must be no partial write to a buffer after +the initial write has been finished. However, partial writes in the initializing +is allowed (fill buffer step by step in a loop e.g.). This means, all buffer +writes needs to dominate all buffer reads. Example for breaking the invariant: @@ -65,15 +65,15 @@ particular result value while not using the resource `SideEffects::AutomaticAllocationScopeResource` (since it is currently reserved for allocations, like `Alloca` that will be automatically deallocated by a -parent scope). Allocations that have not been detected in this phase will not -be tracked internally, and thus, not deallocated automatically. However, -BufferDeallocation is fully compatible with “hybrid” setups in which tracked -and untracked allocations are mixed: +parent scope). Allocations that have not been detected in this phase will not be +tracked internally, and thus, not deallocated automatically. However, +BufferDeallocation is fully compatible with “hybrid” setups in which tracked and +untracked allocations are mixed: ```mlir func @mixedAllocation(%arg0: i1) { - %0 = alloca() : memref<2xf32> // aliases: %2 - %1 = alloc() : memref<2xf32> // aliases: %2 + %0 = memref.alloca() : memref<2xf32> // aliases: %2 + %1 = memref.alloc() : memref<2xf32> // aliases: %2 cond_br %arg0, ^bb1, ^bb2 ^bb1: use(%0) @@ -98,29 +98,29 @@ some cases, it can be useful to use such stack-based buffers instead of heap-based buffers. The conversion is restricted to several constraints like: -* Control flow -* Buffer Size -* Dynamic Size +* Control flow +* Buffer Size +* Dynamic Size -If a buffer is leaving a block, we are not allowed to convert it into an -alloca. If the size of the buffer is large, we could convert it, but regarding -stack overflow, it makes sense to limit the size of these buffers and only -convert small ones. The size can be set via a pass option. The current default -value is 1KB. Furthermore, we can not convert buffers with dynamic size, since -the dimension is not known a priori. +If a buffer is leaving a block, we are not allowed to convert it into an alloca. +If the size of the buffer is large, we could convert it, but regarding stack +overflow, it makes sense to limit the size of these buffers and only convert +small ones. The size can be set via a pass option. The current default value is +1KB. Furthermore, we can not convert buffers with dynamic size, since the +dimension is not known a priori. ## Movement and Placement of Allocations Using the buffer hoisting pass, all buffer allocations are moved as far upwards as possible in order to group them and make upcoming optimizations easier by -limiting the search space. Such a movement is shown in the following graphs. -In addition, we are able to statically free an alloc, if we move it into a -dominator of all of its uses. This simplifies further optimizations (e.g. -buffer fusion) in the future. 
However, movement of allocations is limited by -external data dependencies (in particular in the case of allocations of -dynamically shaped types). Furthermore, allocations can be moved out of nested -regions, if necessary. In order to move allocations to valid locations with -respect to their uses only, we leverage Liveness information. +limiting the search space. Such a movement is shown in the following graphs. In +addition, we are able to statically free an alloc, if we move it into a +dominator of all of its uses. This simplifies further optimizations (e.g. buffer +fusion) in the future. However, movement of allocations is limited by external +data dependencies (in particular in the case of allocations of dynamically +shaped types). Furthermore, allocations can be moved out of nested regions, if +necessary. In order to move allocations to valid locations with respect to their +uses only, we leverage Liveness information. The following code snippets shows a conditional branch before running the BufferHoisting pass: @@ -165,8 +165,8 @@ The alloc is moved from bb2 to the beginning and it is passed as an argument to bb3. -The following example demonstrates an allocation using dynamically shaped -types. Due to the data dependency of the allocation to %0, we cannot move the +The following example demonstrates an allocation using dynamically shaped types. +Due to the data dependency of the allocation to %0, we cannot move the allocation out of bb2 in this case: ```mlir @@ -216,16 +216,16 @@ ``` The first alloc can be safely freed after the live range of its post-dominator -block (bb3). The alloc in bb1 has an alias %2 in bb3 that also keeps this -buffer alive until the end of bb3. Since we cannot determine the actual -branches that will be taken at runtime, we have to ensure that all buffers are -freed correctly in bb3 regardless of the branches we will take to reach the -exit block. This makes it necessary to introduce a copy for %2, which allows us -to free %alloc0 in bb0 and %alloc1 in bb1. Afterwards, we can continue -processing all aliases of %2 (none in this case) and we can safely free %2 at -the end of the sample program. This sample demonstrates that not all -allocations can be safely freed in their associated post-dominator blocks. -Instead, we have to pay attention to all of their aliases. +block (bb3). The alloc in bb1 has an alias %2 in bb3 that also keeps this buffer +alive until the end of bb3. Since we cannot determine the actual branches that +will be taken at runtime, we have to ensure that all buffers are freed correctly +in bb3 regardless of the branches we will take to reach the exit block. This +makes it necessary to introduce a copy for %2, which allows us to free %alloc0 +in bb0 and %alloc1 in bb1. Afterwards, we can continue processing all aliases of +%2 (none in this case) and we can safely free %2 at the end of the sample +program. This sample demonstrates that not all allocations can be safely freed +in their associated post-dominator blocks. Instead, we have to pay attention to +all of their aliases. Applying the BufferDeallocation pass to the program above yields the following result: @@ -253,8 +253,7 @@ Note that a temporary buffer for %2 was introduced to free all allocations properly. Note further that the unnecessary allocation of %3 can be easily -removed using one of the post-pass transformations or the canonicalization -pass. +removed using one of the post-pass transformations or the canonicalization pass. 
The presented example also works with dynamically shaped types. @@ -262,9 +261,9 @@ tracked allocations into account. We initialize the general iteration process using all tracked allocations and their associated aliases. As soon as we encounter an alias that is not properly dominated by our allocation, we mark -this alias as _critical_ (needs to be freed and tracked by the internal -fix-point iteration). The following sample demonstrates the presence of -critical and non-critical aliases: +this alias as *critical* (needs to be freed and tracked by the internal +fix-point iteration). The following sample demonstrates the presence of critical +and non-critical aliases: ![nested_branch_example_pre_move](/includes/img/nested_branch_example_pre_move.svg) @@ -345,8 +344,8 @@ operation. Copies for block arguments are handled by analyzing all predecessor blocks. This is primarily done by querying the `BranchOpInterface` of the associated branch terminators that can jump to the current block. Consider the -following example which involves a simple branch and the critical block -argument %2: +following example which involves a simple branch and the critical block argument +%2: ```mlir custom.br ^bb1(..., %0, : ...) @@ -360,24 +359,24 @@ The `BranchOpInterface` allows us to determine the actual values that will be passed to block bb1 and its argument %2 by analyzing its predecessor blocks. Once we have resolved the values %0 and %1 (that are associated with %2 in this -sample), we can introduce a temporary buffer and clone its contents into the -new buffer. Afterwards, we rewire the branch operands to use the newly -allocated buffer instead. However, blocks can have implicitly defined -predecessors by parent ops that implement the `RegionBranchOpInterface`. This -can be the case if this block argument belongs to the entry block of a region. -In this setting, we have to identify all predecessor regions defined by the -parent operation. For every region, we need to get all terminator operations -implementing the `ReturnLike` trait, indicating that they can branch to our -current block. Finally, we can use a similar functionality as described above -to add the temporary copy. This time, we can modify the terminator operands -directly without touching a high-level interface. +sample), we can introduce a temporary buffer and clone its contents into the new +buffer. Afterwards, we rewire the branch operands to use the newly allocated +buffer instead. However, blocks can have implicitly defined predecessors by +parent ops that implement the `RegionBranchOpInterface`. This can be the case if +this block argument belongs to the entry block of a region. In this setting, we +have to identify all predecessor regions defined by the parent operation. For +every region, we need to get all terminator operations implementing the +`ReturnLike` trait, indicating that they can branch to our current block. +Finally, we can use a similar functionality as described above to add the +temporary copy. This time, we can modify the terminator operands directly +without touching a high-level interface. Consider the following inner-region control-flow sample that uses an imaginary -“custom.region_if” operation. It either executes the “then” or “else” region -and always continues to the “join” region. The “custom.region_if_yield” -operation returns a result to the parent operation. 
This sample demonstrates -the use of the `RegionBranchOpInterface` to determine predecessors in order to -infer the high-level control flow: +“custom.region_if” operation. It either executes the “then” or “else” region and +always continues to the “join” region. The “custom.region_if_yield” operation +returns a result to the parent operation. This sample demonstrates the use of +the `RegionBranchOpInterface` to determine predecessors in order to infer the +high-level control flow: ```mlir func @inner_region_control_flow( @@ -405,7 +404,7 @@ ```mlir func @nested_region_control_flow(%arg0 : index, %arg1 : index) -> memref { - %0 = cmpi "eq", %arg0, %arg1 : index + %0 = arith.cmpi "eq", %arg0, %arg1 : index %1 = memref.alloc(%arg0, %arg0) : memref %2 = scf.if %0 -> (memref) { scf.yield %1 : memref // %2 will be an alias of %1 @@ -420,13 +419,13 @@ ``` In this example, a dealloc is inserted to release the buffer within the else -block since it cannot be accessed by the remainder of the program. Accessing -the `RegionBranchOpInterface`, allows us to infer that %2 is a non-critical -alias of %1 which does not need to be tracked. +block since it cannot be accessed by the remainder of the program. Accessing the +`RegionBranchOpInterface`, allows us to infer that %2 is a non-critical alias of +%1 which does not need to be tracked. ```mlir func @nested_region_control_flow(%arg0: index, %arg1: index) -> memref { - %0 = cmpi "eq", %arg0, %arg1 : index + %0 = arith.cmpi "eq", %arg0, %arg1 : index %1 = memref.alloc(%arg0, %arg0) : memref %2 = scf.if %0 -> (memref) { scf.yield %1 : memref @@ -442,9 +441,9 @@ Analogous to the previous case, we have to detect all terminator operations in all attached regions of “scf.if” that provides a value to its parent operation -(in this sample via scf.yield). Querying the `RegionBranchOpInterface` allows -us to determine the regions that “return” a result to their parent operation. -Like before, we have to update all `ReturnLike` terminators as described above. +(in this sample via scf.yield). Querying the `RegionBranchOpInterface` allows us +to determine the regions that “return” a result to their parent operation. Like +before, we have to update all `ReturnLike` terminators as described above. Reconsider a slightly adapted version of the “custom.region_if” example from above that uses a nested allocation: @@ -468,8 +467,8 @@ Since the allocation %2 happens in a divergent branch and cannot be safely deallocated in a post-dominator, %arg4 will be considered a critical alias. -Furthermore, %arg4 is returned to its parent operation and has an alias %1. -This causes BufferDeallocation to introduce additional copies: +Furthermore, %arg4 is returned to its parent operation and has an alias %1. This +causes BufferDeallocation to introduce additional copies: ```mlir func @inner_region_control_flow_div( @@ -502,9 +501,9 @@ after the last use of the given value. The position can be determined by calculating the common post-dominator of all values using their remaining non-critical aliases. A special-case is the presence of back edges: since such -edges can cause memory leaks when a newly allocated buffer flows back to -another part of the program. In these cases, we need to free the associated -buffer instances from the previous iteration by inserting additional deallocs. +edges can cause memory leaks when a newly allocated buffer flows back to another +part of the program. 
In these cases, we need to free the associated buffer +instances from the previous iteration by inserting additional deallocs. Consider the following “scf.for” use case containing a nested structured control-flow if: @@ -518,7 +517,7 @@ %res: memref<2xf32>) { %0 = scf.for %i = %lb to %ub step %step iter_args(%iterBuf = %buf) -> memref<2xf32> { - %1 = cmpi "eq", %i, %ub : index + %1 = arith.cmpi "eq", %i, %ub : index %2 = scf.if %1 -> (memref<2xf32>) { %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias due to a // divergent allocation @@ -534,18 +533,18 @@ } ``` -In this example, the _then_ branch of the nested “scf.if” operation returns a +In this example, the *then* branch of the nested “scf.if” operation returns a newly allocated buffer. Since this allocation happens in the scope of a divergent branch, %2 becomes a -critical alias that needs to be handled. As before, we have to insert -additional copies to eliminate this alias using copies of %3 and %iterBuf. This -guarantees that %2 will be a newly allocated buffer that is returned in each -iteration. However, “returning” %2 to its alias %iterBuf turns %iterBuf into a -critical alias as well. In other words, we have to create a copy of %2 to pass -it to %iterBuf. Since this jump represents a back edge, and %2 will always be a -new buffer, we have to free the buffer from the previous iteration to avoid -memory leaks: +critical alias that needs to be handled. As before, we have to insert additional +copies to eliminate this alias using copies of %3 and %iterBuf. This guarantees +that %2 will be a newly allocated buffer that is returned in each iteration. +However, “returning” %2 to its alias %iterBuf turns %iterBuf into a critical +alias as well. In other words, we have to create a copy of %2 to pass it to +%iterBuf. Since this jump represents a back edge, and %2 will always be a new +buffer, we have to free the buffer from the previous iteration to avoid memory +leaks: ```mlir func @loop_nested_if( @@ -557,7 +556,7 @@ %4 = memref.clone %buf : (memref<2xf32>) -> (memref<2xf32>) %0 = scf.for %i = %lb to %ub step %step iter_args(%iterBuf = %4) -> memref<2xf32> { - %1 = cmpi "eq", %i, %ub : index + %1 = arith.cmpi "eq", %i, %ub : index %2 = scf.if %1 -> (memref<2xf32>) { %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias use(%3) @@ -612,9 +611,8 @@ If these clones appear with their corresponding dealloc operation within the same block, we can use the canonicalizer to remove these unnecessary operations. Note, that this step needs to take place after the insertion of clones and -deallocs in the buffer deallocation step. The canonicalization inludes both, -the newly created target value from the clone operation and the source -operation. +deallocs in the buffer deallocation step. The canonicalization inludes both, the +newly created target value from the clone operation and the source operation. ## Canonicalization of the Source Buffer of the Clone Operation @@ -653,9 +651,9 @@ operation is also removed. Consider the following example where a generic test operation writes the result -to %temp and then copies %temp to %result. However, these two operations -can be merged into a single step. Canonicalization removes the clone operation -and %temp, and replaces the uses of %temp with %result: +to %temp and then copies %temp to %result. However, these two operations can be +merged into a single step. 
Canonicalization removes the clone operation and +%temp, and replaces the uses of %temp with %result: ```mlir func @reuseTarget(%arg0: memref<2xf32>, %result: memref<2xf32>){ @@ -666,7 +664,7 @@ indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %temp { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): - %tmp2 = exp %gen2_arg0 : f32 + %tmp2 = math.exp %gen2_arg0 : f32 test.yield %tmp2 : f32 }: memref<2xf32>, memref<2xf32> %result = memref.clone %temp : (memref<2xf32>) -> (memref<2xf32>) @@ -685,7 +683,7 @@ indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %result { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): - %tmp2 = exp %gen2_arg0 : f32 + %tmp2 = math.exp %gen2_arg0 : f32 test.yield %tmp2 : f32 }: memref<2xf32>, memref<2xf32> return @@ -697,6 +695,6 @@ BufferDeallocation introduces additional clones from “memref” dialect (“memref.clone”). Analogous, all deallocations use the “memref” dialect-free operation “memref.dealloc”. The actual copy process is realized using -“test.copy”. Furthermore, buffers are essentially immutable after their -creation in a block. Another limitations are known in the case using -unstructered control flow. +“test.copy”. Furthermore, buffers are essentially immutable after their creation +in a block. Another limitations are known in the case using unstructered control +flow. diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md --- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -6,8 +6,8 @@ Bufferization in MLIR is the process of converting the `tensor` type to the `memref` type. MLIR provides a composable system that allows dialects to -systematically bufferize a program. This system is a simple application -of MLIR's [dialect conversion](DialectConversion.md) infrastructure. The bulk of +systematically bufferize a program. This system is a simple application of +MLIR's [dialect conversion](DialectConversion.md) infrastructure. The bulk of the code related to bufferization is a set of ordinary `ConversionPattern`'s that dialect authors write for converting ops that operate on `tensor`'s to ops that operate on `memref`'s. A set of conventions and best practices are followed @@ -34,11 +34,12 @@ w.r.t. control flow. Thus, a realistic compilation pipeline will usually consist of: -1. Bufferization -1. Buffer optimizations such as `buffer-hoisting`, `buffer-loop-hoisting`, and - `promote-buffers-to-stack`, which do optimizations that are only exposed - after bufferization. -1. Finally, running the [buffer deallocation](BufferDeallocationInternals.md) pass. +1. Bufferization +1. Buffer optimizations such as `buffer-hoisting`, `buffer-loop-hoisting`, and + `promote-buffers-to-stack`, which do optimizations that are only exposed + after bufferization. +1. Finally, running the [buffer deallocation](BufferDeallocationInternals.md) + pass. After buffer deallocation has been completed, the program will be quite difficult to transform due to the presence of the deallocation ops. Thus, other @@ -46,8 +47,8 @@ ## General structure of the bufferization process -Bufferization consists of running multiple _partial_ bufferization passes, -followed by one _finalizing_ bufferization pass. +Bufferization consists of running multiple *partial* bufferization passes, +followed by one *finalizing* bufferization pass. There is typically one partial bufferization pass per dialect (though other subdivisions are possible). 
For example, for a dialect `X` there will typically @@ -56,7 +57,7 @@ in the program are incrementally bufferized. Partial bufferization passes create programs where only some ops have been -bufferized. These passes will create _materializations_ (also sometimes called +bufferized. These passes will create *materializations* (also sometimes called "casts") that convert between the `tensor` and `memref` type, which allows bridging between ops that have been bufferized and ops that have not yet been bufferized. @@ -180,8 +181,8 @@ ``` The pass has all the hallmarks of a dialect conversion pass that does type -conversions: a `TypeConverter`, a `RewritePatternSet`, and a -`ConversionTarget`, and a call to `applyPartialConversion`. Note that a function +conversions: a `TypeConverter`, a `RewritePatternSet`, and a `ConversionTarget`, +and a call to `applyPartialConversion`. Note that a function `populateTensorBufferizePatterns` is separated, so that power users can use the patterns independently, if necessary (such as to combine multiple sets of conversion patterns into a single conversion call, for performance). @@ -190,55 +191,59 @@ `BufferizeTypeConverter`, which comes pre-loaded with the necessary conversions and materializations between `tensor` and `memref`. -In this case, the `MemRefOpsDialect` is marked as legal, so the `tensor_load` -and `buffer_cast` ops, which are inserted automatically by the dialect -conversion framework as materializations, are legal. There is a helper -`populateBufferizeMaterializationLegality` +In this case, the `MemRefOpsDialect` is marked as legal, so the +`memref.tensor_load` and `memref.buffer_cast` ops, which are inserted +automatically by the dialect conversion framework as materializations, are +legal. There is a helper `populateBufferizeMaterializationLegality` ([code](https://github.com/llvm/llvm-project/blob/a0b65a7bcd6065688189b3d678c42ed6af9603db/mlir/include/mlir/Transforms/Bufferize.h#L53)) which helps with this in general. ### Other partial bufferization examples -- `linalg-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Linalg/bufferize.mlir#L1)) - - - Bufferizes the `linalg` dialect. - - This is an example of how to simultaneously bufferize all the ops that - satisfy a certain OpInterface with a single pattern. Specifically, - `BufferizeAnyLinalgOp` - ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L170)) - bufferizes any ops that implements the `LinalgOp` interface. - -- `scf-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/SCF/bufferize.mlir#L1)) - - - Bufferizes ops from the `scf` dialect. - - This is an example of how to bufferize ops that implement - `RegionBranchOpInterface` (that is, they use regions to represent control - flow). 
- - The bulk of the work is done by - `lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp` - ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp#L1)), - which is well-commented and covers how to correctly convert ops that contain - regions. - -- `func-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/func-bufferize.mlir#L1)) - - - Bufferizes `func`, `call`, and `BranchOpInterface` ops. - - This is an example of how to bufferize ops that have multi-block regions. - - This is an example of a pass that is not split along dialect subdivisions. - -- `tensor-constant-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir#L1)) - - Bufferizes only `std.constant` ops of `tensor` type. - - This is an example of setting up the legality so that only a subset of - `std.constant` ops get bufferized. - - This is an example of a pass that is not split along dialect subdivisions. +- `linalg-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Linalg/bufferize.mlir#L1)) + + - Bufferizes the `linalg` dialect. + - This is an example of how to simultaneously bufferize all the ops that + satisfy a certain OpInterface with a single pattern. Specifically, + `BufferizeAnyLinalgOp` + ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L170)) + bufferizes any ops that implements the `LinalgOp` interface. + +- `scf-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/SCF/bufferize.mlir#L1)) + + - Bufferizes ops from the `scf` dialect. + - This is an example of how to bufferize ops that implement + `RegionBranchOpInterface` (that is, they use regions to represent + control flow). + - The bulk of the work is done by + `lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp` + ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp#L1)), + which is well-commented and covers how to correctly convert ops that + contain regions. + +- `func-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/func-bufferize.mlir#L1)) + + - Bufferizes `func`, `call`, and `BranchOpInterface` ops. + - This is an example of how to bufferize ops that have multi-block + regions. 
+ - This is an example of a pass that is not split along dialect + subdivisions. + +- `tensor-constant-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir#L1)) + + - Bufferizes only `arith.constant` ops of `tensor` type. + - This is an example of setting up the legality so that only a subset of + `std.constant` ops get bufferized. + - This is an example of a pass that is not split along dialect + subdivisions. ## How to write a finalizing bufferization pass @@ -246,10 +251,10 @@ from the program. The easiest way to write a finalizing bufferize pass is to not write one at all! -MLIR provides a pass `finalizing-bufferize` which eliminates the `tensor_load` / -`buffer_cast` materialization ops inserted by partial bufferization passes -and emits an error if that is not sufficient to remove all tensors from the -program. +MLIR provides a pass `finalizing-bufferize` which eliminates the +`memref.tensor_load` / `memref.buffer_cast` materialization ops inserted by +partial bufferization passes and emits an error if that is not sufficient to +remove all tensors from the program. This pass is sufficient when partial bufferization passes have bufferized all the ops in the program, leaving behind only the materializations. When possible, @@ -260,18 +265,17 @@ unbufferized op. However, before the current bufferization infrastructure was put in place, -bufferization could only be done as a single finalizing bufferization -mega-pass that used the `populate*BufferizePatterns` functions from multiple -dialects to simultaneously bufferize everything at once. Thus, one might see -code in downstream projects structured this way. This structure is not -recommended in new code. A helper, -`populateEliminateBufferizeMaterializationsPatterns` +bufferization could only be done as a single finalizing bufferization mega-pass +that used the `populate*BufferizePatterns` functions from multiple dialects to +simultaneously bufferize everything at once. Thus, one might see code in +downstream projects structured this way. This structure is not recommended in +new code. A helper, `populateEliminateBufferizeMaterializationsPatterns` ([code](https://github.com/llvm/llvm-project/blob/a0b65a7bcd6065688189b3d678c42ed6af9603db/mlir/include/mlir/Transforms/Bufferize.h#L58)) -is available for such passes to provide patterns that eliminate `tensor_load` -and `buffer_cast`. +is available for such passes to provide patterns that eliminate +`memref.tensor_load` and `memref.buffer_cast`. ## Changes since [the talk](#the-talk) -- `func-bufferize` was changed to be a partial conversion pass, and there is a - new `finalizing-bufferize` which serves as a general finalizing bufferization - pass. +- `func-bufferize` was changed to be a partial conversion pass, and there is a + new `finalizing-bufferize` which serves as a general finalizing + bufferization pass. diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -68,8 +68,8 @@ A declarative rewrite rule contains two main components: -* A _source pattern_, which is used for matching a DAG of operations. 
-* One or more _result patterns_, which are used for generating DAGs of +* A *source pattern*, which is used for matching a DAG of operations. +* One or more *result patterns*, which are used for generating DAGs of operations to replace the matched DAG of operations. We allow multiple result patterns to support @@ -380,8 +380,8 @@ ##### `NativeCodeCall` placeholders In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`. -The former is called _special placeholder_, while the latter is called -_positional placeholder_ and _positional range placeholder_. +The former is called *special placeholder*, while the latter is called +*positional placeholder* and *positional range placeholder*. `NativeCodeCall` right now only supports three special placeholders: `$_builder`, `$_loc`, and `$_self`: @@ -405,15 +405,16 @@ ``` In the above, `$_self` is substituted by the defining operation of the first -operand of OneAttrOp. Note that we don't support binding name to `NativeCodeCall` -in the source pattern. To carry some return values from a helper function, put the -names (constraint is optional) in the parameter list and they will be bound to -the variables with correspoding type. Then these names must be either passed by -reference or pointer to the variable used as argument so that the matched value -can be returned. In the same example, `$val` will be bound to a variable with -`Attribute` type (as `I32Attr`) and the type of the second argument in `Foo()` -could be `Attribute&` or `Attribute*`. Names with attribute constraints will be -captured as `Attribute`s while everything else will be treated as `Value`s. +operand of OneAttrOp. Note that we don't support binding name to +`NativeCodeCall` in the source pattern. To carry some return values from a +helper function, put the names (constraint is optional) in the parameter list +and they will be bound to the variables with correspoding type. Then these names +must be either passed by reference or pointer to the variable used as argument +so that the matched value can be returned. In the same example, `$val` will be +bound to a variable with `Attribute` type (as `I32Attr`) and the type of the +second argument in `Foo()` could be `Attribute&` or `Attribute*`. Names with +attribute constraints will be captured as `Attribute`s while everything else +will be treated as `Value`s. Positional placeholders will be substituted by the `dag` object parameters at the `NativeCodeCall` use site. For example, if we define `SomeCall : @@ -445,9 +446,9 @@ The correct number of returned value specified in NativeCodeCall is important. It will be used to verify the consistency of the number of return values. Additionally, `mlir-tblgen` will try to capture the return values of -`NativeCodeCall` in the generated code so that it will trigger a later compilation -error if a `NativeCodeCall` that doesn't return any result isn't labeled with 0 -returns. +`NativeCodeCall` in the generated code so that it will trigger a later +compilation error if a `NativeCodeCall` that doesn't return any result isn't +labeled with 0 returns. ##### Customizing entire op building @@ -471,7 +472,7 @@ ### Supporting auxiliary ops A declarative rewrite rule supports multiple result patterns. One of the -purposes is to allow generating _auxiliary ops_. Auxiliary ops are operations +purposes is to allow generating *auxiliary ops*. Auxiliary ops are operations used for building the replacement ops; but they are not directly used for replacement themselves. 
@@ -486,17 +487,17 @@ want to allocate memory and store some computation (in pseudocode): ```mlir -%dst = addi %lhs, %rhs +%dst = arith.addi %lhs, %rhs ``` into ```mlir %shape = shape %lhs -%mem = alloc %shape -%sum = addi %lhs, %rhs -store %mem, %sum -%dst = load %mem +%mem = memref.alloc %shape +%sum = arith.addi %lhs, %rhs +memref.store %mem, %sum +%dst = memref.load %mem ``` We cannot fit in with just one result pattern given `store` does not return a @@ -610,10 +611,10 @@ Before going into details on variadic op support, we need to define a few terms regarding an op's values. -* _Value_: either an operand or a result -* _Declared operand/result/value_: an operand/result/value statically declared +* *Value*: either an operand or a result +* *Declared operand/result/value*: an operand/result/value statically declared in ODS of the op -* _Actual operand/result/value_: an operand/result/value of an op instance at +* *Actual operand/result/value*: an operand/result/value of an op instance at runtime The above terms are needed because ops can have multiple results, and some of @@ -754,12 +755,12 @@ The `returnType` directive must be used as a trailing argument to a node describing a replacement op. The directive comes in three forms: -* `(returnType $value)`: copy the type of the operand or result bound to - `value`. -* `(returnType "$_builder.getI32Type()")`: a string literal embedding C++. The - embedded snippet is expected to return a `Type` or a `TypeRange`. -* `(returnType (NativeCodeCall<"myFunc($0)"> $value))`: a DAG node with a native - code call that can be passed any bound variables arguments. +* `(returnType $value)`: copy the type of the operand or result bound to + `value`. +* `(returnType "$_builder.getI32Type()")`: a string literal embedding C++. The + embedded snippet is expected to return a `Type` or a `TypeRange`. +* `(returnType (NativeCodeCall<"myFunc($0)"> $value))`: a DAG node with a + native code call that can be passed any bound variables arguments. Specify multiple return types with a mix of any of the above. Example: diff --git a/mlir/docs/Diagnostics.md b/mlir/docs/Diagnostics.md --- a/mlir/docs/Diagnostics.md +++ b/mlir/docs/Diagnostics.md @@ -301,7 +301,7 @@ // Expect an error on an adjacent line. func @foo(%a : f32) { // expected-error@+1 {{unknown comparison predicate "foo"}} - %result = cmpf "foo", %a, %a : f32 + %result = arith.cmpf "foo", %a, %a : f32 return } diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -66,7 +66,7 @@ - This action signals that only some instances of a given operation are legal. This allows for defining fine-tune constraints, e.g. saying that - `addi` is only legal when operating on 32-bit integers. + `arith.addi` is only legal when operating on 32-bit integers. * Illegal diff --git a/mlir/docs/Dialects/Affine.md b/mlir/docs/Dialects/Affine.md --- a/mlir/docs/Dialects/Affine.md +++ b/mlir/docs/Dialects/Affine.md @@ -13,8 +13,8 @@ ### Dimensions and Symbols Dimensions and symbols are the two kinds of identifiers that can appear in the -polyhedral structures, and are always of [`index`](Builtin.md/#indextype) -type. Dimensions are declared in parentheses and symbols are declared in square +polyhedral structures, and are always of [`index`](Builtin.md/#indextype) type. +Dimensions are declared in parentheses and symbols are declared in square brackets. 
Examples: @@ -54,36 +54,34 @@ ```mlir #affine_map2to3 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0, d1 - s0)> // Binds %N to the s0 symbol in affine_map2to3. -%x = alloc()[%N] : memref<40x50xf32, #affine_map2to3> +%x = memref.alloc()[%N] : memref<40x50xf32, #affine_map2to3> ``` ### Restrictions on Dimensions and Symbols The affine dialect imposes certain restrictions on dimension and symbolic identifiers to enable powerful analysis and transformation. An SSA value's use -can be bound to a symbolic identifier if that SSA value is either -1. a region argument for an op with trait `AffineScope` (eg. `FuncOp`), -2. a value defined at the top level of an `AffineScope` op (i.e., immediately -enclosed by the latter), -3. a value that dominates the `AffineScope` op enclosing the value's use, -4. the result of a [`constant` operation](Standard.md/#stdconstant-constantop), -5. the result of an [`affine.apply` -operation](#affineapply-affineapplyop) that recursively takes as arguments any valid -symbolic identifiers, or -6. the result of a [`dim` operation](MemRef.md/#memrefdim-mlirmemrefdimop) on either a -memref that is an argument to a `AffineScope` op or a memref where the -corresponding dimension is either static or a dynamic one in turn bound to a -valid symbol. +can be bound to a symbolic identifier if that SSA value is either 1. a region +argument for an op with trait `AffineScope` (eg. `FuncOp`), 2. a value defined +at the top level of an `AffineScope` op (i.e., immediately enclosed by the +latter), 3. a value that dominates the `AffineScope` op enclosing the value's +use, 4. the result of a +[`constant` operation](Standard.md/#stdconstant-constantop), 5. the result of an +[`affine.apply` operation](#affineapply-affineapplyop) that recursively takes as +arguments any valid symbolic identifiers, or 6. the result of a +[`dim` operation](MemRef.md/#memrefdim-mlirmemrefdimop) on either a memref that +is an argument to a `AffineScope` op or a memref where the corresponding +dimension is either static or a dynamic one in turn bound to a valid symbol. *Note:* if the use of an SSA value is not contained in any op with the `AffineScope` trait, only the rules 4-6 can be applied. Note that as a result of rule (3) above, symbol validity is sensitive to the -location of the SSA use. Dimensions may be bound not only to anything that a +location of the SSA use. Dimensions may be bound not only to anything that a symbol is bound to, but also to induction variables of enclosing [`affine.for`](#affinefor-affineforop) and -[`affine.parallel`](#affineparallel-affineparallelop) operations, and the result of an -[`affine.apply` operation](#affineapply-affineapplyop) (which recursively may use -other dimensions and symbols). +[`affine.parallel`](#affineparallel-affineparallelop) operations, and the result +of an [`affine.apply` operation](#affineapply-affineapplyop) (which recursively +may use other dimensions and symbols). ### Affine Expressions @@ -119,24 +117,24 @@ ceildiv, and (4) addition and subtraction. All of these operators associate from left to right. -A _multidimensional affine expression_ is a comma separated list of +A *multidimensional affine expression* is a comma separated list of one-dimensional affine expressions, with the entire list enclosed in parentheses. **Context:** An affine function, informally, is a linear function plus a constant. 
More formally, a function f defined on a vector $\vec{v} \in -\mathbb{Z}^n$ is a multidimensional affine function of $\vec{v}$ if -$f(\vec{v})$ can be expressed in the form $M \vec{v} + \vec{c}$ where $M$ -is a constant matrix from $\mathbb{Z}^{m \times n}$ and $\vec{c}$ is a -constant vector from $\mathbb{Z}$. $m$ is the dimensionality of such an -affine function. MLIR further extends the definition of an affine function to -allow 'floordiv', 'ceildiv', and 'mod' with respect to positive integer -constants. Such extensions to affine functions have often been referred to as -quasi-affine functions by the polyhedral compiler community. MLIR uses the term -'affine map' to refer to these multidimensional quasi-affine functions. As -examples, $(i+j+1, j)$, $(i \mod 2, j+i)$, $(j, i/4, i \mod 4)$, $(2i+1, -j)$ are two-dimensional affine functions of $(i, j)$, but $(i \cdot j, -i^2)$, $(i \mod j, i/j)$ are not affine functions of $(i, j)$. +\mathbb{Z}^n$ is a multidimensional affine function of $\vec{v}$ if $f(\vec{v})$ +can be expressed in the form $M \vec{v} + \vec{c}$ where $M$ is a constant +matrix from $\mathbb{Z}^{m \times n}$ and $\vec{c}$ is a constant vector from +$\mathbb{Z}$. $m$ is the dimensionality of such an affine function. MLIR further +extends the definition of an affine function to allow 'floordiv', 'ceildiv', and +'mod' with respect to positive integer constants. Such extensions to affine +functions have often been referred to as quasi-affine functions by the +polyhedral compiler community. MLIR uses the term 'affine map' to refer to these +multidimensional quasi-affine functions. As examples, $(i+j+1, j)$, $(i \mod 2, +j+i)$, $(j, i/4, i \mod 4)$, $(2i+1, j)$ are two-dimensional affine functions of +$(i, j)$, but $(i \cdot j, i^2)$, $(i \mod j, i/j)$ are not affine functions of +$(i, j)$. ### Affine Maps @@ -157,9 +155,9 @@ combining the indices and symbols. Affine maps distinguish between [indices and symbols](#dimensions-and-symbols) because indices are inputs to the affine map when the map is called (through an operation such as -[affine.apply](#affineapply-affineapplyop)), whereas symbols are bound when -the map is established (e.g. when a memref is formed, establishing a -memory [layout map](Builtin.md/#layout-map)). +[affine.apply](#affineapply-affineapplyop)), whereas symbols are bound when the +map is established (e.g. when a memref is formed, establishing a memory +[layout map](Builtin.md/#layout-map)). Affine maps are used for various core structures in MLIR. The restrictions we impose on their form allows powerful analysis and transformation, while keeping @@ -192,10 +190,10 @@ // Use an affine mapping definition in an alloc operation, binding the // SSA value %N to the symbol s0. -%a = alloc()[%N] : memref<4x4xf32, #affine_map42> +%a = memref.alloc()[%N] : memref<4x4xf32, #affine_map42> // Same thing with an inline affine mapping definition. -%b = alloc()[%N] : memref<4x4xf32, affine_map<(d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)>> +%b = memref.alloc()[%N] : memref<4x4xf32, affine_map<(d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)>> ``` ### Semi-affine maps @@ -378,23 +376,21 @@ The `affine.dma_start` op starts a non-blocking DMA operation that transfers data from a source memref to a destination memref. The source and destination memref need not be of the same dimensionality, but need to have the same -elemental type. 
The operands include the source and destination memref's -each followed by its indices, size of the data transfer in terms of the -number of elements (of the elemental type of the memref), a tag memref with -its indices, and optionally at the end, a stride and a -number_of_elements_per_stride arguments. The tag location is used by an -AffineDmaWaitOp to check for completion. The indices of the source memref, -destination memref, and the tag memref have the same restrictions as any -affine.load/store. In particular, index for each memref dimension must be an -affine expression of loop induction variables and symbols. -The optional stride arguments should be of 'index' type, and specify a -stride for the slower memory space (memory space with a lower memory space -id), transferring chunks of number_of_elements_per_stride every stride until -%num_elements are transferred. Either both or no stride arguments should be -specified. The value of 'num_elements' must be a multiple of +elemental type. The operands include the source and destination memref's each +followed by its indices, size of the data transfer in terms of the number of +elements (of the elemental type of the memref), a tag memref with its indices, +and optionally at the end, a stride and a number_of_elements_per_stride +arguments. The tag location is used by an AffineDmaWaitOp to check for +completion. The indices of the source memref, destination memref, and the tag +memref have the same restrictions as any affine.load/store. In particular, index +for each memref dimension must be an affine expression of loop induction +variables and symbols. The optional stride arguments should be of 'index' type, +and specify a stride for the slower memory space (memory space with a lower +memory space id), transferring chunks of number_of_elements_per_stride every +stride until %num_elements are transferred. Either both or no stride arguments +should be specified. The value of 'num_elements' must be a multiple of 'number_of_elements_per_stride'. - Example: ```mlir @@ -403,8 +399,8 @@ space 1 at indices [%k + 7, %l], would be specified as follows: %num_elements = constant 256 - %idx = constant 0 : index - %tag = alloc() : memref<1xi32, 4> + %idx = arith.constant 0 : index + %tag = memref.alloc() : memref<1xi32, 4> affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], %num_elements : memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> @@ -426,10 +422,10 @@ ``` The `affine.dma_start` op blocks until the completion of a DMA operation -associated with the tag element '%tag[%index]'. %tag is a memref, and %index -has to be an index with the same restrictions as any load/store index. -In particular, index for each memref dimension must be an affine expression of -loop induction variables and symbols. %num_elements is the number of elements +associated with the tag element '%tag[%index]'. %tag is a memref, and %index has +to be an index with the same restrictions as any load/store index. In +particular, index for each memref dimension must be an affine expression of loop +induction variables and symbols. %num_elements is the number of elements associated with the DMA operation. 
For example: Example: diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md --- a/mlir/docs/Dialects/Linalg/_index.md +++ b/mlir/docs/Dialects/Linalg/_index.md @@ -125,14 +125,14 @@ #map0 = affine_map<(d0) -> (d0 * 2 + 1)> func @example(%arg0: memref, %arg1: memref, #map0>) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = dim %arg0, %c0 : memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.dim %arg0, %c0 : memref scf.for %arg2 = %c0 to %0 step %c1 { - %1 = load %arg0[%arg2] : memref - %2 = load %arg1[%arg2] : memref, #map0> + %1 = memref.load %arg0[%arg2] : memref + %2 = memref.load %arg1[%arg2] : memref, #map0> %3 = "some_compute"(%1, %2) : (f32, vector<4xf32>) -> vector<4xf32> - store %3, %arg1[%arg2] : memref, #map0> + memref.store %3, %arg1[%arg2] : memref, #map0> } return } @@ -207,16 +207,16 @@ #map0 = affine_map<(d0, d1) -> (d0 * 2 + d1 * 2)> func @example(%arg0: memref<8x?xf32, #map0>, %arg1: memref>) { - %c8 = constant 8 : index - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = dim %arg0, %c1 : memref<8x?xf32, #map0> + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.dim %arg0, %c1 : memref<8x?xf32, #map0> scf.for %arg2 = %c0 to %0 step %c1 { scf.for %arg3 = %c0 to %c8 step %c1 { - %1 = load %arg0[%arg3, %arg2] : memref<8x?xf32, #map0> - %2 = load %arg1[%arg3] : memref> + %1 = memref.load %arg0[%arg3, %arg2] : memref<8x?xf32, #map0> + %2 = memref.load %arg1[%arg3] : memref> %3 = "some_compute"(%1, %2) : (f32, vector<4xf32>) -> vector<4xf32> - store %3, %arg1[%arg3] : memref> + memref.store %3, %arg1[%arg3] : memref> } } return @@ -314,7 +314,7 @@ ins(%A, %B: memref, memref) outs(%C: memref) { ^bb0(%a: f32, %b: f32, %c: f32): - %d = addf %a, %b : f32 + %d = arith.addf %a, %b : f32 linalg.yield %d : f32 } @@ -330,16 +330,16 @@ ```mlir func @example(%arg0: memref, %arg1: memref, %arg2: memref) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = dim %arg0, %c0 : memref - %1 = dim %arg0, %c1 : memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.dim %arg0, %c1 : memref scf.for %arg3 = %c0 to %0 step %c1 { scf.for %arg4 = %c0 to %1 step %c1 { - %2 = load %arg0[%arg3, %arg4] : memref - %3 = load %arg1[%arg3, %arg4] : memref - %4 = addf %2, %3 : f32 - store %4, %arg2[%arg3, %arg4] : memref + %2 = memref.load %arg0[%arg3, %arg4] : memref + %3 = memref.load %arg1[%arg3, %arg4] : memref + %4 = arith.addf %2, %3 : f32 + memref.store %4, %arg2[%arg3, %arg4] : memref } } return @@ -387,7 +387,7 @@ ins(%A, %B: memref, memref) outs(%C: memref) { ^bb0(%a: f32, %b: f32, %c: f32): - %d = addf %a, %b : f32 + %d = arith.addf %a, %b : f32 linalg.yield %d : f32 } return @@ -518,7 +518,7 @@ ``` * `memref.view`, -* `std.subview`, +* `memref.subview`, * `memref.transpose`. * `linalg.range`, * `linalg.slice`, diff --git a/mlir/docs/Dialects/MemRef.md b/mlir/docs/Dialects/MemRef.md --- a/mlir/docs/Dialects/MemRef.md +++ b/mlir/docs/Dialects/MemRef.md @@ -16,7 +16,7 @@ Syntax: ``` -operation ::= `dma_start` ssa-use`[`ssa-use-list`]` `,` +operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,` ssa-use`[`ssa-use-list`]` `,` ssa-use `,` ssa-use`[`ssa-use-list`]` (`,` ssa-use `,` ssa-use)? `:` memref-type `,` memref-type `,` memref-type @@ -39,17 +39,17 @@ destination memref need not be of the same dimensionality, but need to have the same elemental type. 
-For example, a `dma_start` operation that transfers 32 vector elements from a -memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` would be -specified as shown below. +For example, a `memref.dma_start` operation that transfers 32 vector elements +from a memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` would +be specified as shown below. Example: ```mlir -%size = constant 32 : index -%tag = alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4> -%idx = constant 0 : index -dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] : +%size = arith.constant 32 : index +%tag = memref.alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4> +%idx = arith.constant 0 : index +memref.dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] : memref<40 x 8 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 0>, memref<2 x 4 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 2>, memref<1 x i32>, affine_map<(d0) -> (d0)>, 4> @@ -60,7 +60,7 @@ Syntax: ``` -operation ::= `dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type +operation ::= `memref.dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type ``` Blocks until the completion of a DMA operation associated with the tag element @@ -72,5 +72,5 @@ Example: ```mlir -dma_wait %tag[%idx], %size : memref<1 x i32, affine_map<(d0) -> (d0)>, 4> +memref.dma_wait %tag[%idx], %size : memref<1 x i32, affine_map<(d0) -> (d0)>, 4> ``` diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md --- a/mlir/docs/Dialects/Vector.md +++ b/mlir/docs/Dialects/Vector.md @@ -3,26 +3,27 @@ [TOC] MLIR supports multi-dimensional `vector` types and custom operations on those -types. A generic, retargetable, higher-order ``vector`` type (`n-D` with `n > -1`) is a structured type, that carries semantic information useful for -transformations. This document discusses retargetable abstractions that exist -in MLIR today and operate on ssa-values of type `vector` along with pattern +types. A generic, retargetable, higher-order `vector` type (`n-D` with `n > 1`) +is a structured type, that carries semantic information useful for +transformations. This document discusses retargetable abstractions that exist in +MLIR today and operate on ssa-values of type `vector` along with pattern rewrites and lowerings that enable targeting specific instructions on concrete targets. These abstractions serve to separate concerns between operations on -`memref` (a.k.a buffers) and operations on ``vector`` values. This is not a -new proposal but rather a textual documentation of existing MLIR components -along with a rationale. +`memref` (a.k.a buffers) and operations on `vector` values. This is not a new +proposal but rather a textual documentation of existing MLIR components along +with a rationale. ## Positioning in the Codegen Infrastructure -The following diagram, recently presented with the [StructuredOps -abstractions](https://drive.google.com/corp/drive/u/0/folders/1sRAsgsd8Bvpm_IxREmZf2agsGU2KvrK-), + +The following diagram, recently presented with the +[StructuredOps abstractions](https://drive.google.com/corp/drive/u/0/folders/1sRAsgsd8Bvpm_IxREmZf2agsGU2KvrK-), captures the current codegen paths implemented in MLIR in the various existing lowering paths. 
![](https://user-images.githubusercontent.com/10148468/71177417-f78e4d80-2239-11ea-92ef-700f42ea503f.png) -The following diagram seeks to isolate `vector` dialects from the complexity -of the codegen paths and focus on the payload-carrying ops that operate on std -and `vector` types. This diagram is not to be taken as set in stone and +The following diagram seeks to isolate `vector` dialects from the complexity of +the codegen paths and focus on the payload-carrying ops that operate on std and +`vector` types. This diagram is not to be taken as set in stone and representative of what exists today but rather illustrates the layering of abstractions in MLIR. @@ -31,164 +32,165 @@ This  separates concerns related to (a) defining efficient operations on `vector` types from (b) program analyses + transformations on `memref`, loops and other types of structured ops (be they `HLO`, `LHLO`, `Linalg` or other ). -Looking a bit forward in time, we can put a stake in the ground and venture -that the higher level of `vector`-level primitives we build and target from -codegen (or some user/language level), the simpler our task will be, the more -complex patterns can be expressed and the better performance will be. +Looking a bit forward in time, we can put a stake in the ground and venture that +the higher level of `vector`-level primitives we build and target from codegen +(or some user/language level), the simpler our task will be, the more complex +patterns can be expressed and the better performance will be. ## Components of a Generic Retargetable Vector-Level Dialect -The existing MLIR `vector`-level dialects are related to the following -bottom-up abstractions: - -1. Representation in `LLVMIR` via data structures, instructions and -intrinsics. This is referred to as the `LLVM` level. -2. Set of machine-specific operations and types that are built to translate -almost 1-1 with the HW ISA. This is referred to as the Hardware Vector level; -a.k.a `HWV`. For instance, we have (a) the `NVVM` dialect (for `CUDA`) with -tensor core ops, (b) accelerator-specific dialects (internal), a potential -(future) `CPU` dialect to capture `LLVM` intrinsics more closely and other -dialects for specific hardware. Ideally this should be auto-generated as much -as possible from the `LLVM` level. -3. Set of virtual, machine-agnostic, operations that are informed by costs at -the `HWV`-level. This is referred to as the Virtual Vector level; a.k.a -`VV`. This is the level that higher-level abstractions (codegen, automatic -vectorization, potential vector language, ...) targets. + +The existing MLIR `vector`-level dialects are related to the following bottom-up +abstractions: + +1. Representation in `LLVMIR` via data structures, instructions and intrinsics. + This is referred to as the `LLVM` level. +2. Set of machine-specific operations and types that are built to translate + almost 1-1 with the HW ISA. This is referred to as the Hardware Vector + level; a.k.a `HWV`. For instance, we have (a) the `NVVM` dialect (for + `CUDA`) with tensor core ops, (b) accelerator-specific dialects (internal), + a potential (future) `CPU` dialect to capture `LLVM` intrinsics more closely + and other dialects for specific hardware. Ideally this should be + auto-generated as much as possible from the `LLVM` level. +3. Set of virtual, machine-agnostic, operations that are informed by costs at + the `HWV`-level. This is referred to as the Virtual Vector level; a.k.a + `VV`. 
This is the level that higher-level abstractions (codegen, automatic + vectorization, potential vector language, ...) targets. The existing generic, retargetable, `vector`-level dialect is related to the following top-down rewrites and conversions: -1. MLIR Rewrite Patterns applied by the MLIR `PatternRewrite` infrastructure -to progressively lower to implementations that match closer and closer to the -`HWV`. Some patterns are "in-dialect" `VV -> VV` and some are conversions `VV --> HWV`. -2. `Virtual Vector -> Hardware Vector` lowering is specified as a set of MLIR -lowering patterns that are specified manually for now. -3. `Hardware Vector -> LLVM` lowering is a mechanical process that is written -manually at the moment and that should be automated, following the `LLVM -> -Hardware Vector` ops generation as closely as possible. +1. MLIR Rewrite Patterns applied by the MLIR `PatternRewrite` infrastructure to + progressively lower to implementations that match closer and closer to the + `HWV`. Some patterns are "in-dialect" `VV -> VV` and some are conversions + `VV -> HWV`. +2. `Virtual Vector -> Hardware Vector` lowering is specified as a set of MLIR + lowering patterns that are specified manually for now. +3. `Hardware Vector -> LLVM` lowering is a mechanical process that is written + manually at the moment and that should be automated, following the `LLVM -> + Hardware Vector` ops generation as closely as possible. ## Short Description of the Existing Infrastructure ### LLVM level -On CPU, the `n-D` `vector` type currently lowers to -`!llvm>`. More concretely, `vector<4x8x128xf32>` lowers to -`!llvm<[4 x [ 8 x [ 128 x float ]]]>`. -There are tradeoffs involved related to how one can access subvectors and how -one uses `llvm.extractelement`, `llvm.insertelement` and -`llvm.shufflevector`. A [deeper dive section](#DeeperDive) discusses the -current lowering choices and tradeoffs. + +On CPU, the `n-D` `vector` type currently lowers to `!llvm>`. More +concretely, `vector<4x8x128xf32>` lowers to `!llvm<[4 x [ 8 x [ 128 x float +]]]>`. There are tradeoffs involved related to how one can access subvectors and +how one uses `llvm.extractelement`, `llvm.insertelement` and +`llvm.shufflevector`. A [deeper dive section](#DeeperDive) discusses the current +lowering choices and tradeoffs. ### Hardware Vector Ops -Hardware Vector Ops are implemented as one dialect per target. -For internal hardware, we are auto-generating the specific HW dialects. -For `GPU`, the `NVVM` dialect adds operations such as `mma.sync`, `shfl` and -tests. -For `CPU` things are somewhat in-flight because the abstraction is close to -`LLVMIR`. The jury is still out on  whether a generic `CPU` dialect is -concretely needed, but it seems reasonable to have the same levels of -abstraction for all targets and perform cost-based lowering decisions in MLIR -even for `LLVM`. -Specialized `CPU` dialects that would capture specific features not well -captured by LLVM peephole optimizations of on different types that core MLIR -supports (e.g. Scalable Vectors) are welcome future extensions. + +Hardware Vector Ops are implemented as one dialect per target. For internal +hardware, we are auto-generating the specific HW dialects. For `GPU`, the `NVVM` +dialect adds operations such as `mma.sync`, `shfl` and tests. For `CPU` things +are somewhat in-flight because the abstraction is close to `LLVMIR`. 
The jury is
+still out on whether a generic `CPU` dialect is concretely needed, but it seems
+reasonable to have the same levels of abstraction for all targets and perform
+cost-based lowering decisions in MLIR even for `LLVM`. Specialized `CPU`
+dialects that would capture specific features not well captured by LLVM peephole
+optimizations or on different types that core MLIR supports (e.g. Scalable
+Vectors) are welcome future extensions.

### Virtual Vector Ops

-Some existing Standard and Vector Dialect on `n-D` `vector` types comprise:
-```
-%2 = std.addf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32>
-%2 = std.mulf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32>
-%2 = std.splat %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32>
-
-%1 = vector.extract %0[1]: vector<3x7x8xf32> // -> vector<7x8xf32>
-%1 = vector.extract %0[1, 5]: vector<3x7x8xf32> // -> vector<8xf32>
-%2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32> // -> vector<4x8xf32>
-%3 = vector.outerproduct %0, %1, %2: vector<4xf32>, vector<8xf32> // fma when adding %2
-%3 = vector.strided_slice %0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]}:
-  vector<4x8x16xf32> // Returns a slice of type vector<2x2x16xf32>
-
-%2 = vector.transfer_read %A[%0, %1]
-  {permutation_map = (d0, d1) -> (d0)}: memref<7x?xf32>, vector<4xf32>
-
-vector.transfer_write %f1, %A[%i0, %i1, %i2, %i3]
-  {permutation_map = (d0, d1, d2, d3) -> (d3, d1, d0)} :
-  vector<5x4x3xf32>, memref
-```
-
-The list of Vector is currently undergoing evolutions and is best kept
-track of by following the evolution of the
+
+Some existing Arithmetic, Standard and Vector Dialect operations on `n-D`
+`vector` types comprise:
+
+```
+%2 = arith.addf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32>
+%2 = arith.mulf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32>
+%2 = std.splat %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32>
+
+%1 = vector.extract %0[1]: vector<3x7x8xf32> // -> vector<7x8xf32>
+%1 = vector.extract %0[1, 5]: vector<3x7x8xf32> // -> vector<8xf32>
+%2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32> // -> vector<4x8xf32>
+%3 = vector.outerproduct %0, %1, %2: vector<4xf32>, vector<8xf32> // fma when adding %2
+%3 = vector.strided_slice %0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]}:
+  vector<4x8x16xf32> // Returns a slice of type vector<2x2x16xf32>
+
+%2 = vector.transfer_read %A[%0, %1]
+  {permutation_map = (d0, d1) -> (d0)}: memref<7x?xf32>, vector<4xf32>
+
+vector.transfer_write %f1, %A[%i0, %i1, %i2, %i3]
+  {permutation_map = (d0, d1, d2, d3) -> (d3, d1, d0)} :
+  vector<5x4x3xf32>, memref
+```
+
+The list of Vector operations is currently evolving and is best kept track of
+by following the evolution of the
[VectorOps.td](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Vector/VectorOps.td)
ODS file (markdown documentation is automatically generated locally when
-building and populates the [Vector
-doc](https://github.com/llvm/llvm-project/blob/main/mlir/docs/Dialects/Vector.md)). Recent
-extensions are driven by concrete use cases of interest. A notable such use
-case is the `vector.contract` op which applies principles of the StructuredOps
-abstraction to `vector` types.
+building and populates the
+[Vector doc](https://github.com/llvm/llvm-project/blob/main/mlir/docs/Dialects/Vector.md)).
+Recent extensions are driven by concrete use cases of interest. A notable such
+use case is the `vector.contract` op which applies principles of the
+StructuredOps abstraction to `vector` types.
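+
+As a purely illustrative sketch of that last point (the attribute names below
+follow the `vector.contract` documentation, but the exact spelling should be
+checked against the current ODS definition), a matrix-multiply-and-accumulate
+can be expressed directly on `vector` values:
+
+```mlir
+#matmat_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// %lhs : vector<4x8xf32>, %rhs : vector<8x16xf32>, %acc : vector<4x16xf32>.
+%res = vector.contract #matmat_trait %lhs, %rhs, %acc
+    : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
+```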
### Virtual Vector Rewrite Patterns The following rewrite patterns exist at the `VV->VV` level: -1. The now retired `MaterializeVector` pass used to legalize ops on a -coarse-grained virtual `vector` to a finer-grained virtual `vector` by -unrolling. This has been rewritten as a retargetable unroll-and-jam pattern on -`vector` ops and `vector` types. -2. The lowering of `vector_transfer` ops legalizes `vector` load/store ops to -permuted loops over scalar load/stores. This should evolve to loops over -`vector` load/stores + `mask` operations as they become available `vector` ops -at the `VV` level. - -The general direction is to add more Virtual Vector level ops and implement -more useful `VV -> VV` rewrites as composable patterns that the PatternRewrite +1. The now retired `MaterializeVector` pass used to legalize ops on a + coarse-grained virtual `vector` to a finer-grained virtual `vector` by + unrolling. This has been rewritten as a retargetable unroll-and-jam pattern + on `vector` ops and `vector` types. +2. The lowering of `vector_transfer` ops legalizes `vector` load/store ops to + permuted loops over scalar load/stores. This should evolve to loops over + `vector` load/stores + `mask` operations as they become available `vector` + ops at the `VV` level. + +The general direction is to add more Virtual Vector level ops and implement more +useful `VV -> VV` rewrites as composable patterns that the PatternRewrite infrastructure can apply iteratively. ### Virtual Vector to Hardware Vector Lowering -For now, `VV -> HWV` are specified in C++ (see for instance the -[SplatOpLowering for n-D -vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d) -or the [VectorOuterProductOp -lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)). - -Simple [conversion -tests](https://github.com/llvm/llvm-project/blob/main/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir) + +For now, `VV -> HWV` are specified in C++ (see for instance the +[SplatOpLowering for n-D vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d) +or the +[VectorOuterProductOp lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)). + +Simple +[conversion tests](https://github.com/llvm/llvm-project/blob/main/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir) are available for the `LLVM` target starting from the Virtual Vector Level. ## Rationale + ### Hardware as `vector` Machines of Minimum Granularity Higher-dimensional `vector`s are ubiquitous in modern HPC hardware. One way to think about Generic Retargetable `vector`-Level Dialect is that it operates on `vector` types that are multiples of a "good" `vector` size so the HW can -efficiently implement a set of high-level primitives -(e.g. `vector<8x8x8x16xf32>` when HW `vector` size is say `vector<4x8xf32>`). +efficiently implement a set of high-level primitives (e.g. +`vector<8x8x8x16xf32>` when HW `vector` size is say `vector<4x8xf32>`). Some notable `vector` sizes of interest include: -1. CPU: `vector`, `vector` and `vector` -2. GPU: `vector`, `vector` and -`vector` for tensor_core sizes, -3. Other accelerators: n-D `vector` as first-class citizens in the HW. +1. CPU: `vector`, `vector` and `vector` +2. GPU: `vector`, `vector` and + `vector` for tensor_core sizes, +3. Other accelerators: n-D `vector` as first-class citizens in the HW. 
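+
+For instance (shapes chosen purely for illustration), a single coarse-grained
+elementwise op keeps the tiling decision in the type rather than in explicit
+loop structure; a `vector<4x8xf32>` HW unit would execute it as a grid of
+native operations after unrolling:
+
+```mlir
+%sum = arith.addf %a, %b : vector<32x32xf32>
+```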
-Depending on the target, ops on sizes that are not multiples of the HW -`vector` size may either produce slow code (e.g. by going through `LLVM` -legalization) or may not legalize at all (e.g. some unsupported accelerator X -combination of ops and types). +Depending on the target, ops on sizes that are not multiples of the HW `vector` +size may either produce slow code (e.g. by going through `LLVM` legalization) or +may not legalize at all (e.g. some unsupported accelerator X combination of ops +and types). ### Transformations Problems Avoided + A `vector<16x32x64xf32>` virtual `vector` is a coarse-grained type that can be “unrolled” to HW-specific sizes. The multi-dimensional unrolling factors are carried in the IR by the `vector` type. After unrolling, traditional instruction-level scheduling can be run. The following key transformations (along with the supporting analyses and -structural constraints) are completely avoided by operating on a ``vector`` +structural constraints) are completely avoided by operating on a `vector` `ssa-value` abstraction: -1. Loop unroll and unroll-and-jam. -2. Loop and load-store restructuring for register reuse. -3. Load to store forwarding and Mem2reg. -4. Coarsening (raising) from finer-grained `vector` form. +1. Loop unroll and unroll-and-jam. +2. Loop and load-store restructuring for register reuse. +3. Load to store forwarding and Mem2reg. +4. Coarsening (raising) from finer-grained `vector` form. Note that “unrolling” in the context of `vector`s corresponds to partial loop unroll-and-jam and not full unrolling. As a consequence this is expected to @@ -196,73 +198,71 @@ up. ### The Big Out-Of-Scope Piece: Automatic Vectorization -One important piece not discussed here is automatic vectorization -(automatically raising from scalar to n-D `vector` ops and types). The TL;DR -is that when the first "super-vectorization" prototype was implemented, MLIR -was nowhere near as mature as it is today. As we continue building more -abstractions in `VV -> HWV`, there is an opportunity to revisit vectorization -in MLIR. + +One important piece not discussed here is automatic vectorization (automatically +raising from scalar to n-D `vector` ops and types). The TL;DR is that when the +first "super-vectorization" prototype was implemented, MLIR was nowhere near as +mature as it is today. As we continue building more abstractions in `VV -> HWV`, +there is an opportunity to revisit vectorization in MLIR. Since this topic touches on codegen abstractions, it is technically out of the scope of this survey document but there is a lot to discuss in light of -structured op type representations and how a vectorization transformation can -be reused across dialects. In particular, MLIR allows the definition of -dialects at arbitrary levels of granularity and lends itself favorably to -progressive lowering. The argument can be made that automatic vectorization on -a loops + ops abstraction is akin to raising structural information that has -been lost. Instead, it is possible to revisit vectorization as simple pattern -rewrites, provided the IR is in a suitable form. For instance, vectorizing a -`linalg.generic` op whose semantics match a `matmul` can be done [quite easily -with a -pattern](https://github.com/tensorflow/mlir/commit/bff722d6b59ab99b998f0c2b9fccd0267d9f93b5). In -fact this pattern is trivial to generalize to any type of contraction when +structured op type representations and how a vectorization transformation can be +reused across dialects. 
In particular, MLIR allows the definition of dialects at +arbitrary levels of granularity and lends itself favorably to progressive +lowering. The argument can be made that automatic vectorization on a loops + ops +abstraction is akin to raising structural information that has been lost. +Instead, it is possible to revisit vectorization as simple pattern rewrites, +provided the IR is in a suitable form. For instance, vectorizing a +`linalg.generic` op whose semantics match a `matmul` can be done +[quite easily with a pattern](https://github.com/tensorflow/mlir/commit/bff722d6b59ab99b998f0c2b9fccd0267d9f93b5). +In fact this pattern is trivial to generalize to any type of contraction when targeting the `vector.contract` op, as well as to any field (`+/*`, `min/+`, -`max/+`, `or/and`, `logsumexp/+` ...) . In other words, by operating on a -higher level of generic abstractions than affine loops, non-trivial -transformations become significantly simpler and composable at a finer -granularity. +`max/+`, `or/and`, `logsumexp/+` ...) . In other words, by operating on a higher +level of generic abstractions than affine loops, non-trivial transformations +become significantly simpler and composable at a finer granularity. Irrespective of the existence of an auto-vectorizer, one can build a notional -vector language based on the VectorOps dialect and build end-to-end models -with expressing `vector`s in the IR directly and simple -pattern-rewrites. [EDSC](https://github.com/llvm/llvm-project/blob/main/mlir/docs/EDSC.md)s +vector language based on the VectorOps dialect and build end-to-end models with +expressing `vector`s in the IR directly and simple pattern-rewrites. +[EDSC](https://github.com/llvm/llvm-project/blob/main/mlir/docs/EDSC.md)s provide a simple way of driving such a notional language directly in C++. ## Bikeshed Naming Discussion -There are arguments against naming an n-D level of abstraction `vector` -because most people associate it with 1-D `vector`s. On the other hand, -`vector`s are first-class n-D values in MLIR. -The alternative name Tile has been proposed, which conveys higher-D -meaning. But it also is one of the most overloaded terms in compilers and -hardware. -For now, we generally use the `n-D` `vector` name and are open to better -suggestions. + +There are arguments against naming an n-D level of abstraction `vector` because +most people associate it with 1-D `vector`s. On the other hand, `vector`s are +first-class n-D values in MLIR. The alternative name Tile has been proposed, +which conveys higher-D meaning. But it also is one of the most overloaded terms +in compilers and hardware. For now, we generally use the `n-D` `vector` name and +are open to better suggestions. ## DeeperDive This section describes the tradeoffs involved in lowering the MLIR n-D vector -type and operations on it to LLVM-IR. Putting aside the [LLVM -Matrix](http://lists.llvm.org/pipermail/llvm-dev/2018-October/126871.html) -proposal for now, this assumes LLVM only has built-in support for 1-D -vector. The relationship with the LLVM Matrix proposal is discussed at the end -of this document. +type and operations on it to LLVM-IR. Putting aside the +[LLVM Matrix](http://lists.llvm.org/pipermail/llvm-dev/2018-October/126871.html) +proposal for now, this assumes LLVM only has built-in support for 1-D vector. +The relationship with the LLVM Matrix proposal is discussed at the end of this +document. MLIR does not currently support dynamic vector sizes (i.e. 
SVE style) so the -discussion is limited to static rank and static vector sizes -(e.g. `vector<4x8x16x32xf32>`). This section discusses operations on vectors -in LLVM and MLIR. - -LLVM instructions are prefixed by the `llvm.` dialect prefix -(e.g. `llvm.insertvalue`). Such ops operate exclusively on 1-D vectors and -aggregates following the [LLVM LangRef](https://llvm.org/docs/LangRef.html). -MLIR operations are prefixed by the `vector.` dialect prefix -(e.g. `vector.insertelement`). Such ops operate exclusively on MLIR `n-D` -`vector` types. +discussion is limited to static rank and static vector sizes (e.g. +`vector<4x8x16x32xf32>`). This section discusses operations on vectors in LLVM +and MLIR. + +LLVM instructions are prefixed by the `llvm.` dialect prefix (e.g. +`llvm.insertvalue`). Such ops operate exclusively on 1-D vectors and aggregates +following the [LLVM LangRef](https://llvm.org/docs/LangRef.html). MLIR +operations are prefixed by the `vector.` dialect prefix (e.g. +`vector.insertelement`). Such ops operate exclusively on MLIR `n-D` `vector` +types. ### Alternatives For Lowering an n-D Vector Type to LLVM -Consider a vector of rank n with static sizes `{s_0, ... s_{n-1}}` (i.e. an -MLIR `vector`). Lowering such an `n-D` MLIR vector type to -an LLVM descriptor can be done by either: + +Consider a vector of rank n with static sizes `{s_0, ... s_{n-1}}` (i.e. an MLIR +`vector`). Lowering such an `n-D` MLIR vector type to an +LLVM descriptor can be done by either: 1. Flattening to a `1-D` vector: `!llvm<"(s_0*...*s_{n-1})xfloat">` in the MLIR LLVM dialect. @@ -277,33 +277,26 @@ "k" minor dimensions. ### Constraints Inherited from LLVM (see LangRef) + The first constraint was already mentioned: LLVM only supports `1-D` `vector` -types natively. -Additional constraints are related to the difference in LLVM between vector -and aggregate types: -``` - “Aggregate Types are a subset of derived types that can contain multiple - member types. Arrays and structs are aggregate types. Vectors are not - considered to be aggregate types.”. -``` - -This distinction is also reflected in some of the operations. For `1-D` -vectors, the operations `llvm.extractelement`, `llvm.insertelement`, and +types natively. Additional constraints are related to the difference in LLVM +between vector and aggregate types: `“Aggregate Types are a subset of derived +types that can contain multiple member types. Arrays and structs are aggregate +types. Vectors are not considered to be aggregate types.”.` + +This distinction is also reflected in some of the operations. For `1-D` vectors, +the operations `llvm.extractelement`, `llvm.insertelement`, and `llvm.shufflevector` apply, with direct support for dynamic indices. For `n-D` -vectors with `n>1`, and thus aggregate types at LLVM level, the more -restrictive operations `llvm.extractvalue` and `llvm.insertvalue` apply, which -only accept static indices. There is no direct shuffling support for aggregate -types. +vectors with `n>1`, and thus aggregate types at LLVM level, the more restrictive +operations `llvm.extractvalue` and `llvm.insertvalue` apply, which only accept +static indices. There is no direct shuffling support for aggregate types. 
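+
+To make the contrast concrete, a minimal sketch (the ops are the actual LLVM
+dialect operations; the index and type syntax shown here is an assumption and
+may differ between MLIR versions):
+
+```mlir
+// 1-D vectors support dynamic element indexing.
+%c1 = llvm.mlir.constant(1 : i64) : i64
+%e  = llvm.extractelement %v1d[%c1 : i64] : vector<8xf32>
+
+// An n-D vector lowers to an aggregate (an array of 1-D vectors), so only
+// static indices are available, and there is no shuffle at this level.
+%row = llvm.extractvalue %v2d[1] : !llvm.array<4 x vector<8xf32>>
+```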
-The next sentence illustrates a recurrent tradeoff, also found in MLIR, -between “value types” (subject to SSA use-def chains) and “memory types” -(subject to aliasing and side-effects): -``` -“Structures in memory are accessed using ‘load’ and ‘store’ by getting a -pointer to a field with the llvm.getelementptr instruction. Structures in -registers are accessed using the llvm.extractvalue and llvm.insertvalue -instructions.” -``` +The next sentence illustrates a recurrent tradeoff, also found in MLIR, between +“value types” (subject to SSA use-def chains) and “memory types” (subject to +aliasing and side-effects): `“Structures in memory are accessed using ‘load’ and +‘store’ by getting a pointer to a field with the llvm.getelementptr instruction. +Structures in registers are accessed using the llvm.extractvalue and +llvm.insertvalue instructions.”` When transposing this to MLIR, `llvm.getelementptr` works on pointers to `n-D` vectors in memory. For `n-D`, vectors values that live in registers we can use @@ -320,175 +313,176 @@ are discussed in the following sections. ### Nested Aggregate + Pros: -1. Natural encoding n-D vector -> (n-1)-D aggregate over 1-D vector. -2. No need for linearization / delinearization logic inserted everywhere. -3. `llvm.insertvalue`, `llvm.extractvalue` of `(n-k)-D` aggregate is natural. -4. `llvm.insertelement`, `llvm.extractelement`, `llvm.shufflevector` over -`1-D` vector type is natural. +1. Natural encoding n-D vector -> (n-1)-D aggregate over 1-D vector. +2. No need for linearization / delinearization logic inserted everywhere. +3. `llvm.insertvalue`, `llvm.extractvalue` of `(n-k)-D` aggregate is natural. +4. `llvm.insertelement`, `llvm.extractelement`, `llvm.shufflevector` over `1-D` + vector type is natural. Cons: -1. `llvm.insertvalue` / `llvm.extractvalue` does not accept dynamic indices -but only static ones. -2. Dynamic indexing on the non-most-minor dimension requires roundtrips to -memory. -3. Special intrinsics and native instructions in LLVM operate on `1-D` -vectors. This is not expected to be a practical limitation thanks to a -`vector.cast %0: vector<4x8x16x32xf32> to vector<4x4096xf32>` operation, that -flattens the most minor dimensions (see the bigger picture in implications on -codegen). +1. `llvm.insertvalue` / `llvm.extractvalue` does not accept dynamic indices but + only static ones. +2. Dynamic indexing on the non-most-minor dimension requires roundtrips to + memory. +3. Special intrinsics and native instructions in LLVM operate on `1-D` vectors. + This is not expected to be a practical limitation thanks to a `vector.cast + %0: vector<4x8x16x32xf32> to vector<4x4096xf32>` operation, that flattens + the most minor dimensions (see the bigger picture in implications on + codegen). ### Flattened 1-D Vector Type Pros: -1. `insertelement` / `extractelement` / `shufflevector` with dynamic indexing -is possible over the whole lowered `n-D` vector type. -2. Supports special intrinsics and native operations. +1. `insertelement` / `extractelement` / `shufflevector` with dynamic indexing + is possible over the whole lowered `n-D` vector type. +2. Supports special intrinsics and native operations. -Cons: -1. Requires linearization/delinearization logic everywhere, translations are -complex. -2. Hides away the real HW structure behind dynamic indexing: at the end of the -day, HW vector sizes are generally fixed and multiple vectors will be needed -to hold a vector that is larger than the HW. -3. 
Unlikely peephole optimizations will result in good code: arbitrary dynamic
-accesses, especially at HW vector boundaries unlikely to result in regular
-patterns.
+Cons:
+
+1. Requires linearization/delinearization logic everywhere, translations are
+   complex.
+2. Hides away the real HW structure behind dynamic indexing: at the end of the
+   day, HW vector sizes are generally fixed and multiple vectors will be needed
+   to hold a vector that is larger than the HW.
+3. It is unlikely that peephole optimizations will result in good code:
+   arbitrary dynamic accesses, especially at HW vector boundaries, are unlikely
+   to produce regular patterns.

### Discussion
+
#### HW Vectors and Implications on the SW and the Programming Model
+
As of today, the LLVM model only support `1-D` vector types. This is
unsurprising because historically, the vast majority of HW only supports `1-D`
vector registers. We note that multiple HW vendors are in the process of
evolving to higher-dimensional physical vectors.

-In the following discussion, let's assume the HW vector size is `1-D` and the
-SW vector size is `n-D`, with `n >= 1`. The same discussion would apply with
-`2-D` HW `vector` size and `n >= 2`. In this context, most HW exhibit a vector
-register file. The number of such vectors is fixed.
-Depending on the rank and sizes of the SW vector abstraction and the HW vector
-sizes and number of registers, an `n-D` SW vector type may be materialized by
-a mix of multiple `1-D` HW vector registers + memory locations at a given
-point in time.
-
-The implication of the physical HW constraints on the programming model are
-that one cannot index dynamically across hardware registers: a register file
-can generally not be indexed dynamically. This is because the register number
-is fixed and one either needs to unroll explicitly to obtain fixed register
-numbers or go through memory. This is a constraint familiar to CUDA
-programmers: when declaring a `private float a[4]`; and subsequently indexing
-with a *dynamic* value results in so-called **local memory** usage
-(i.e. roundtripping to memory).
+In the following discussion, let's assume the HW vector size is `1-D` and the SW
+vector size is `n-D`, with `n >= 1`. The same discussion would apply with `2-D`
+HW `vector` size and `n >= 2`. In this context, most HW exhibit a vector
+register file. The number of such vectors is fixed. Depending on the rank and
+sizes of the SW vector abstraction and the HW vector sizes and number of
+registers, an `n-D` SW vector type may be materialized by a mix of multiple
+`1-D` HW vector registers + memory locations at a given point in time.
+
+The implication of the physical HW constraints on the programming model is that
+one cannot index dynamically across hardware registers: a register file can
+generally not be indexed dynamically. This is because the register number is
+fixed and one either needs to unroll explicitly to obtain fixed register numbers
+or go through memory. This is a constraint familiar to CUDA programmers:
+declaring a `private float a[4]` and subsequently indexing it with a *dynamic*
+value results in so-called **local memory** usage (i.e. roundtripping to
+memory).

#### Implication on codegen
+
MLIR `n-D` vector types are currently represented as `(n-1)-D` arrays of `1-D`
-vectors when lowered to LLVM.
-This introduces the consequences on static vs dynamic indexing discussed
-previously: `extractelement`, `insertelement` and `shufflevector` on `n-D`
-vectors in MLIR only support static indices.
Dynamic indices are only -supported on the most minor `1-D` vector but not the outer `(n-1)-D`. -For other cases, explicit load / stores are required. +vectors when lowered to LLVM. This introduces the consequences on static vs +dynamic indexing discussed previously: `extractelement`, `insertelement` and +`shufflevector` on `n-D` vectors in MLIR only support static indices. Dynamic +indices are only supported on the most minor `1-D` vector but not the outer +`(n-1)-D`. For other cases, explicit load / stores are required. The implications on codegen are as follows: -1. Loops around `vector` values are indirect addressing of vector values, they -must operate on explicit load / store operations over `n-D` vector types. -2. Once an `n-D` `vector` type is loaded into an SSA value (that may or may -not live in `n` registers, with or without spilling, when eventually lowered), -it may be unrolled to smaller `k-D` `vector` types and operations that -correspond to the HW. This level of MLIR codegen is related to register -allocation and spilling that occur much later in the LLVM pipeline. -3. HW may support >1-D vectors with intrinsics for indirect addressing within -these vectors. These can be targeted thanks to explicit `vector_cast` -operations from MLIR `k-D` vector types and operations to LLVM `1-D` vectors + -intrinsics. - -Alternatively, we argue that directly lowering to a linearized abstraction -hides away the codegen complexities related to memory accesses by giving a -false impression of magical dynamic indexing across registers. Instead we -prefer to make those very explicit in MLIR and allow codegen to explore -tradeoffs. -Different HW will require different tradeoffs in the sizes involved in steps -1., 2. and 3. - -Decisions made at the MLIR level will have implications at a much later stage -in LLVM (after register allocation). We do not envision to expose concerns -related to modeling of register allocation and spilling to MLIR -explicitly. Instead, each target will expose a set of "good" target operations -and `n-D` vector types, associated with costs that `PatterRewriters` at the -MLIR level will be able to target. Such costs at the MLIR level will be -abstract and used for ranking, not for accurate performance modeling. In the -future such costs will be learned. +1. Loops around `vector` values are indirect addressing of vector values, they + must operate on explicit load / store operations over `n-D` vector types. +2. Once an `n-D` `vector` type is loaded into an SSA value (that may or may not + live in `n` registers, with or without spilling, when eventually lowered), + it may be unrolled to smaller `k-D` `vector` types and operations that + correspond to the HW. This level of MLIR codegen is related to register + allocation and spilling that occur much later in the LLVM pipeline. +3. HW may support >1-D vectors with intrinsics for indirect addressing within + these vectors. These can be targeted thanks to explicit `vector_cast` + operations from MLIR `k-D` vector types and operations to LLVM `1-D` + vectors + intrinsics. + +Alternatively, we argue that directly lowering to a linearized abstraction hides +away the codegen complexities related to memory accesses by giving a false +impression of magical dynamic indexing across registers. Instead we prefer to +make those very explicit in MLIR and allow codegen to explore tradeoffs. +Different HW will require different tradeoffs in the sizes involved in steps 1., +2. and 3. 
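+
+A small sketch of the static-vs-dynamic indexing constraint described above
+(values and shapes are assumed, for illustration only):
+
+```mlir
+%c3 = arith.constant 3 : i32
+
+// Static position into the outer dimension: fine on an n-D vector value.
+%row = vector.extract %v[1] : vector<4x8xf32>
+
+// Dynamic position: only available on the innermost 1-D vector; anything
+// else has to go through explicit loads/stores.
+%elem = vector.extractelement %row[%c3 : i32] : vector<8xf32>
+```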
+
+Decisions made at the MLIR level will have implications at a much later stage in
+LLVM (after register allocation). We do not envision exposing concerns related
+to modeling of register allocation and spilling to MLIR explicitly. Instead,
+each target will expose a set of "good" target operations and `n-D` vector
+types, associated with costs that `PatternRewriters` at the MLIR level will be
+able to target. Such costs at the MLIR level will be abstract and used for
+ranking, not for accurate performance modeling. In the future, such costs will
+be learned.

#### Implication on Lowering to Accelerators
+
-To target accelerators that support higher dimensional vectors natively, we
-can start from either `1-D` or `n-D` vectors in MLIR and use `vector.cast` to
+To target accelerators that support higher dimensional vectors natively, we can
+start from either `1-D` or `n-D` vectors in MLIR and use `vector.cast` to
flatten the most minor dimensions to `1-D` `vector` where `K` is an
appropriate constant. Then, the existing lowering to LLVM-IR immediately
applies, with extensions for accelerator-specific intrinsics.

It is the role of an Accelerator-specific vector dialect (see codegen flow in
-the figure above) to lower the `vector.cast`. Accelerator -> LLVM lowering
-would then consist of a bunch of `Accelerator -> Accelerator` rewrites to
-perform the casts composed with `Accelerator -> LLVM` conversions + intrinsics
-that operate on `1-D` `vector`.
+the figure above) to lower the `vector.cast`. Accelerator -> LLVM lowering would
+then consist of a bunch of `Accelerator -> Accelerator` rewrites to perform the
+casts composed with `Accelerator -> LLVM` conversions + intrinsics that operate
+on `1-D` `vector`.

Some of those rewrites may need extra handling, especially if a reduction is
-involved. For example, `vector.cast %0: vector to
-vector` when `K != K1 * … * Kn` and some arbitrary irregular
-`vector.cast %0: vector<4x4x17xf32> to vector` may introduce masking
-and intra-vector shuffling that may not be worthwhile or even feasible,
-i.e. infinite cost.
+involved. For example, `vector.cast %0: vector to vector`
+when `K != K1 * … * Kn` and some arbitrary irregular `vector.cast %0:
+vector<4x4x17xf32> to vector` may introduce masking and intra-vector
+shuffling that may not be worthwhile or even feasible, i.e. infinite cost.

-However `vector.cast %0: vector to vector` when `K =
-K1 * … * Kn` should be close to a noop.
+However `vector.cast %0: vector to vector` when `K = K1 *
+… * Kn` should be close to a noop.

As we start building accelerator-specific abstractions, we hope to achieve
-retargetable codegen: the same infra is used for CPU, GPU and accelerators
-with extra MLIR patterns and costs.
+retargetable codegen: the same infra is used for CPU, GPU and accelerators with
+extra MLIR patterns and costs.

#### Implication on calling external functions that operate on vectors
+
It is possible (likely) that we additionally need to linearize when calling an
external function.

### Relationship to LLVM matrix type proposal.
+
The LLVM matrix proposal was formulated 1 year ago but seemed to be somewhat
stalled until recently. In its current form, it is limited to 2-D matrix types
-and operations are implemented with LLVM intrinsics.
-In contrast, MLIR sits at a higher level of abstraction and allows the
-lowering of generic operations on generic n-D vector types from MLIR to
-aggregates of 1-D LLVM vectors.
-In the future, it could make sense to lower to the LLVM matrix abstraction
-also for CPU even though MLIR will continue needing higher level abstractions.
-
-On the other hand, one should note that as MLIR is moving to LLVM, this
-document could become the unifying abstraction that people should target for
->1-D vectors and the LLVM matrix proposal can be viewed as a subset of this
-work.
+and operations are implemented with LLVM intrinsics. In contrast, MLIR sits at a
+higher level of abstraction and allows the lowering of generic operations on
+generic n-D vector types from MLIR to aggregates of 1-D LLVM vectors. In the
+future, it could make sense to lower to the LLVM matrix abstraction also for CPU
+even though MLIR will continue needing higher level abstractions.
+
+On the other hand, one should note that as MLIR is moving to LLVM, this document
+could become the unifying abstraction that people should target for >1-D
+vectors, and the LLVM matrix proposal can be viewed as a subset of this work.

### Conclusion
+
The flattened 1-D vector design in the LLVM matrix proposal is good in a
HW-specific world with special intrinsics. This is a good abstraction for
register allocation, Instruction-Level-Parallelism and
SoftWare-Pipelining/Modulo Scheduling optimizations at the register level.

-However MLIR codegen operates at a higher level of abstraction where we want
-to target operations on coarser-grained vectors than the HW size and on which
+However, MLIR codegen operates at a higher level of abstraction where we want to
+target operations on coarser-grained vectors than the HW size and on which
unroll-and-jam is applied and patterns across multiple HW vectors can be
matched. This makes “nested aggregate type of 1-D vector” an appealing
abstraction for lowering from MLIR because:

-1. it does not hide complexity related to the buffer vs value semantics and
-the memory subsystem and
-2. it does not rely on LLVM to magically make all the things work from a too
-low-level abstraction.
+1. it does not hide complexity related to the buffer vs value semantics and the
+   memory subsystem and
+2. it does not rely on LLVM to magically make all the things work from a too
+   low-level abstraction.

-The use of special intrinsics in a `1-D` LLVM world is still available thanks
-to an explicit `vector.cast` op.
+The use of special intrinsics in a `1-D` LLVM world is still available thanks to
+an explicit `vector.cast` op.

## Operations

diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md
--- a/mlir/docs/Dialects/emitc.md
+++ b/mlir/docs/Dialects/emitc.md
@@ -1,35 +1,37 @@
-The EmitC dialect allows to convert operations from other MLIR dialects to
-EmitC ops. Those can be translated to C/C++ via the Cpp emitter.
+The EmitC dialect allows converting operations from other MLIR dialects to
+EmitC ops. Those can be translated to C/C++ via the Cpp emitter.

The following convention is followed:

-* If template arguments are passed to an `emitc.call` operation,
-  C++ is generated.
-* If tensors are used, C++ is generated.
-* If multiple return values are used within in a functions or an
-  `emitc.call` operation, C++11 is required.
-* If floating-point type template arguments are passed to an `emitc.call`
-  operation, C++20 is required.
-* Else the generated code is compatible with C99.
+* If template arguments are passed to an `emitc.call` operation, C++ is
+  generated.
+* If tensors are used, C++ is generated.
+* If multiple return values are used within a function or an `emitc.call`
+  operation, C++11 is required.
+* If floating-point type template arguments are passed to an `emitc.call`
+  operation, C++20 is required.
+* Else the generated code is compatible with C99.

These restrictions are neither inherent to the EmitC dialect itself nor to the
Cpp emitter and therefore need to be considered while implementing conversions.

After the conversion, C/C++ code can be emitted with `mlir-translate`. The tool
-supports translating MLIR to C/C++ by passing `-mlir-to-cpp`.
-Furthermore, code with variables declared at top can be generated by passing
-the additional argument `-declare-variables-at-top`.
+supports translating MLIR to C/C++ by passing `-mlir-to-cpp`. Furthermore, code
+with variables declared at top can be generated by passing the additional
+argument `-declare-variables-at-top`.

Besides operations part of the EmitC dialect, the Cpp targets supports
translating the following operations:

-* 'std' Dialect
-  * `std.br`
-  * `std.call`
-  * `std.cond_br`
-  * `std.constant`
-  * `std.return`
-* 'scf' Dialect
-  * `scf.for`
-  * `scf.if`
-  * `scf.yield`
+* 'std' Dialect
+  * `std.br`
+  * `std.call`
+  * `std.cond_br`
+  * `std.constant`
+  * `std.return`
+* 'scf' Dialect
+  * `scf.for`
+  * `scf.if`
+  * `scf.yield`
+* 'arith' Dialect
+  * `arith.constant`

diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -11,17 +11,17 @@
continuous design provides a framework to lower from dataflow graphs to
high-performance target-specific code.

-This document defines and describes the key concepts in MLIR, and is intended
-to be a dry reference document - the [rationale
-documentation](Rationale/Rationale.md),
+This document defines and describes the key concepts in MLIR, and is intended to
+be a dry reference document - the
+[rationale documentation](Rationale/Rationale.md),
[glossary](../getting_started/Glossary.md), and other content are hosted
elsewhere.

MLIR is designed to be used in three different forms: a human-readable textual
form suitable for debugging, an in-memory form suitable for programmatic
-transformations and analysis, and a compact serialized form suitable for
-storage and transport. The different forms all describe the same semantic
-content. This document describes the human-readable textual form.
+transformations and analysis, and a compact serialized form suitable for storage
+and transport. The different forms all describe the same semantic content. This
+document describes the human-readable textual form.

[TOC]

@@ -29,34 +29,31 @@

MLIR is fundamentally based on a graph-like data structure of nodes, called
*Operations*, and edges, called *Values*. Each Value is the result of exactly
-one Operation or Block Argument, and has a *Value Type* defined by the [type
-system](#type-system). [Operations](#operations) are contained in
+one Operation or Block Argument, and has a *Value Type* defined by the
+[type system](#type-system). [Operations](#operations) are contained in
[Blocks](#blocks) and Blocks are contained in [Regions](#regions). Operations
are also ordered within their containing block and Blocks are ordered in their
-containing region, although this order may or may not be semantically
-meaningful in a given [kind of region](Interfaces.md/#regionkindinterfaces)).
-Operations may also contain regions, enabling hierarchical structures to be
-represented.
+containing region, although this order may or may not be semantically meaningful +in a given [kind of region](Interfaces.md/#regionkindinterfaces)). Operations +may also contain regions, enabling hierarchical structures to be represented. Operations can represent many different concepts, from higher-level concepts -like function definitions, function calls, buffer allocations, view or slices -of buffers, and process creation, to lower-level concepts like -target-independent arithmetic, target-specific instructions, configuration -registers, and logic gates. These different concepts are represented by -different operations in MLIR and the set of operations usable in MLIR can be -arbitrarily extended. +like function definitions, function calls, buffer allocations, view or slices of +buffers, and process creation, to lower-level concepts like target-independent +arithmetic, target-specific instructions, configuration registers, and logic +gates. These different concepts are represented by different operations in MLIR +and the set of operations usable in MLIR can be arbitrarily extended. MLIR also provides an extensible framework for transformations on operations, using familiar concepts of compiler [Passes](Passes.md). Enabling an arbitrary -set of passes on an arbitrary set of operations results in a significant -scaling challenge, since each transformation must potentially take into -account the semantics of any operation. MLIR addresses this complexity by -allowing operation semantics to be described abstractly using -[Traits](Traits.md) and [Interfaces](Interfaces.md), enabling transformations -to operate on operations more generically. Traits often describe verification -constraints on valid IR, enabling complex invariants to be captured and -checked. (see [Op vs -Operation](Tutorials/Toy/Ch-2.md/#op-vs-operation-using-mlir-operations)) +set of passes on an arbitrary set of operations results in a significant scaling +challenge, since each transformation must potentially take into account the +semantics of any operation. MLIR addresses this complexity by allowing operation +semantics to be described abstractly using [Traits](Traits.md) and +[Interfaces](Interfaces.md), enabling transformations to operate on operations +more generically. Traits often describe verification constraints on valid IR, +enabling complex invariants to be captured and checked. (see +[Op vs Operation](Tutorials/Toy/Ch-2.md/#op-vs-operation-using-mlir-operations)) One obvious application of MLIR is to represent an [SSA-based](https://en.wikipedia.org/wiki/Static_single_assignment_form) IR, @@ -76,26 +73,26 @@ // known. The shapes are assumed to match. func @mul(%A: tensor<100x?xf32>, %B: tensor) -> (tensor<100x50xf32>) { // Compute the inner dimension of %A using the dim operation. - %n = dim %A, 1 : tensor<100x?xf32> + %n = memref.dim %A, 1 : tensor<100x?xf32> // Allocate addressable "buffers" and copy tensors %A and %B into them. - %A_m = alloc(%n) : memref<100x?xf32> - tensor_store %A to %A_m : memref<100x?xf32> + %A_m = memref.alloc(%n) : memref<100x?xf32> + memref.tensor_store %A to %A_m : memref<100x?xf32> - %B_m = alloc(%n) : memref - tensor_store %B to %B_m : memref + %B_m = memref.alloc(%n) : memref + memref.tensor_store %B to %B_m : memref // Call function @multiply passing memrefs as arguments, // and getting returned the result of the multiplication. 
%C_m = call @multiply(%A_m, %B_m) : (memref<100x?xf32>, memref) -> (memref<100x50xf32>) - dealloc %A_m : memref<100x?xf32> - dealloc %B_m : memref + memref.dealloc %A_m : memref<100x?xf32> + memref.dealloc %B_m : memref // Load the buffer data into a higher level "tensor" value. - %C = tensor_load %C_m : memref<100x50xf32> - dealloc %C_m : memref<100x50xf32> + %C = memref.tensor_load %C_m : memref<100x50xf32> + memref.dealloc %C_m : memref<100x50xf32> // Call TensorFlow built-in function to print the result tensor. "tf.Print"(%C){message: "mul result"} @@ -108,22 +105,22 @@ func @multiply(%A: memref<100x?xf32>, %B: memref) -> (memref<100x50xf32>) { // Compute the inner dimension of %A. - %n = dim %A, 1 : memref<100x?xf32> + %n = memref.dim %A, 1 : memref<100x?xf32> // Allocate memory for the multiplication result. - %C = alloc() : memref<100x50xf32> + %C = memref.alloc() : memref<100x50xf32> // Multiplication loop nest. affine.for %i = 0 to 100 { affine.for %j = 0 to 50 { - store 0 to %C[%i, %j] : memref<100x50xf32> + memref.store 0 to %C[%i, %j] : memref<100x50xf32> affine.for %k = 0 to %n { - %a_v = load %A[%i, %k] : memref<100x?xf32> - %b_v = load %B[%k, %j] : memref - %prod = mulf %a_v, %b_v : f32 - %c_v = load %C[%i, %j] : memref<100x50xf32> - %sum = addf %c_v, %prod : f32 - store %sum, %C[%i, %j] : memref<100x50xf32> + %a_v = memref.load %A[%i, %k] : memref<100x?xf32> + %b_v = memref.load %B[%k, %j] : memref + %prod = arith.mulf %a_v, %b_v : f32 + %c_v = memref.load %C[%i, %j] : memref<100x50xf32> + %sum = arith.addf %c_v, %prod : f32 + memref.store %sum, %C[%i, %j] : memref<100x50xf32> } } } @@ -134,9 +131,9 @@ ## Notation MLIR has a simple and unambiguous grammar, allowing it to reliably round-trip -through a textual form. This is important for development of the compiler - -e.g. for understanding the state of code as it is being transformed and -writing test cases. +through a textual form. This is important for development of the compiler - e.g. +for understanding the state of code as it is being transformed and writing test +cases. This document describes the grammar using [Extended Backus-Naur Form (EBNF)](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form). @@ -201,12 +198,12 @@ value-use-list ::= value-use (`,` value-use)* ``` -Identifiers name entities such as values, types and functions, and are -chosen by the writer of MLIR code. Identifiers may be descriptive (e.g. -`%batch_size`, `@matmul`), or may be non-descriptive when they are -auto-generated (e.g. `%23`, `@func42`). Identifier names for values may be -used in an MLIR text file but are not persisted as part of the IR - the printer -will give them anonymous names like `%42`. +Identifiers name entities such as values, types and functions, and are chosen by +the writer of MLIR code. Identifiers may be descriptive (e.g. `%batch_size`, +`@matmul`), or may be non-descriptive when they are auto-generated (e.g. `%23`, +`@func42`). Identifier names for values may be used in an MLIR text file but are +not persisted as part of the IR - the printer will give them anonymous names +like `%42`. MLIR guarantees identifiers never collide with keywords by prefixing identifiers with a sigil (e.g. `%`, `#`, `@`, `^`, `!`). In certain unambiguous contexts @@ -214,22 +211,20 @@ keywords may be added to future versions of MLIR without danger of collision with existing identifiers. 
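+
+A compact, purely illustrative snippet showing each sigil in use:
+
+```mlir
+#map0 = affine_map<(d0) -> (d0)>   // `#` introduces an attribute alias.
+!vec8 = type vector<8xf32>         // `!` introduces a type alias / dialect type.
+
+func @double(%x: i64) -> i64 {     // `@` names a symbol, `%` names values.
+  br ^body(%x : i64)               // `^` names a block.
+^body(%arg: i64):
+  %r = arith.addi %arg, %arg : i64
+  return %r : i64
+}
+```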
-Value identifiers are only [in scope](#value-scoping) for the (nested) -region in which they are defined and cannot be accessed or referenced -outside of that region. Argument identifiers in mapping functions are -in scope for the mapping body. Particular operations may further limit -which identifiers are in scope in their regions. For instance, the -scope of values in a region with [SSA control flow -semantics](#control-flow-and-ssacfg-regions) is constrained according -to the standard definition of [SSA -dominance](https://en.wikipedia.org/wiki/Dominator_\(graph_theory\)). Another -example is the [IsolatedFromAbove trait](Traits.md/#isolatedfromabove), -which restricts directly accessing values defined in containing -regions. +Value identifiers are only [in scope](#value-scoping) for the (nested) region in +which they are defined and cannot be accessed or referenced outside of that +region. Argument identifiers in mapping functions are in scope for the mapping +body. Particular operations may further limit which identifiers are in scope in +their regions. For instance, the scope of values in a region with +[SSA control flow semantics](#control-flow-and-ssacfg-regions) is constrained +according to the standard definition of +[SSA dominance](https://en.wikipedia.org/wiki/Dominator_\(graph_theory\)). +Another example is the [IsolatedFromAbove trait](Traits.md/#isolatedfromabove), +which restricts directly accessing values defined in containing regions. Function identifiers and mapping identifiers are associated with -[Symbols](SymbolsAndSymbolTables.md) and have scoping rules dependent on -symbol attributes. +[Symbols](SymbolsAndSymbolTables.md) and have scoping rules dependent on symbol +attributes. ## Dialects @@ -260,9 +255,9 @@ operations directly through to MLIR. As an example, some targets go through LLVM. LLVM has a rich set of intrinsics for certain target-independent operations (e.g. addition with overflow check) as well as providing access to -target-specific operations for the targets it supports (e.g. vector -permutation operations). LLVM intrinsics in MLIR are represented via -operations that start with an "llvm." name. +target-specific operations for the targets it supports (e.g. vector permutation +operations). LLVM intrinsics in MLIR are represented via operations that start +with an "llvm." name. Example: @@ -293,21 +288,21 @@ trailing-location ::= (`loc` `(` location `)`)? ``` -MLIR introduces a uniform concept called _operations_ to enable describing -many different levels of abstractions and computations. Operations in MLIR are -fully extensible (there is no fixed list of operations) and have -application-specific semantics. For example, MLIR supports [target-independent -operations](Dialects/Standard.md#memory-operations), [affine -operations](Dialects/Affine.md), and [target-specific machine -operations](#target-specific-operations). +MLIR introduces a uniform concept called *operations* to enable describing many +different levels of abstractions and computations. Operations in MLIR are fully +extensible (there is no fixed list of operations) and have application-specific +semantics. For example, MLIR supports +[target-independent operations](Dialects/Standard.md#memory-operations), +[affine operations](Dialects/Affine.md), and +[target-specific machine operations](#target-specific-operations). The internal representation of an operation is simple: an operation is identified by a unique string (e.g. 
`dim`, `tf.Conv2d`, `x86.repmovsb`, -`ppc.eieio`, etc), can return zero or more results, take zero or more -operands, has a dictionary of [attributes](#attributes), has zero or more -successors, and zero or more enclosed [regions](#regions). The generic printing -form includes all these elements literally, with a function type to indicate the -types of the results and operands. +`ppc.eieio`, etc), can return zero or more results, take zero or more operands, +has a dictionary of [attributes](#attributes), has zero or more successors, and +zero or more enclosed [regions](#regions). The generic printing form includes +all these elements literally, with a function type to indicate the types of the +results and operands. Example: @@ -325,7 +320,7 @@ ``` In addition to the basic syntax above, dialects may register known operations. -This allows those dialects to support _custom assembly form_ for parsing and +This allows those dialects to support *custom assembly form* for parsing and printing operations. In the operation sets listed below, we show both forms. ### Builtin Operations @@ -352,27 +347,27 @@ block-arg-list ::= `(` value-id-and-type-list? `)` ``` -A *Block* is a list of operations. In [SSACFG -regions](#control-flow-and-ssacfg-regions), each block represents a compiler -[basic block](https://en.wikipedia.org/wiki/Basic_block) where instructions -inside the block are executed in order and terminator operations implement -control flow branches between basic blocks. - -A region with a single block may not include a [terminator -operation](#terminator-operations). The enclosing op can opt-out of this -requirement with the `NoTerminator` trait. The top-level `ModuleOp` is an -example of such operation which defined this trait and whose block body does -not have a terminator. - -Blocks in MLIR take a list of block arguments, notated in a function-like -way. Block arguments are bound to values specified by the semantics of -individual operations. Block arguments of the entry block of a region are also -arguments to the region and the values bound to these arguments are determined -by the semantics of the containing operation. Block arguments of other blocks -are determined by the semantics of terminator operations, e.g. Branches, which -have the block as a successor. In regions with [control -flow](#control-flow-and-ssacfg-regions), MLIR leverages this structure to -implicitly represent the passage of control-flow dependent values without the +A *Block* is a list of operations. In +[SSACFG regions](#control-flow-and-ssacfg-regions), each block represents a +compiler [basic block](https://en.wikipedia.org/wiki/Basic_block) where +instructions inside the block are executed in order and terminator operations +implement control flow branches between basic blocks. + +A region with a single block may not include a +[terminator operation](#terminator-operations). The enclosing op can opt-out of +this requirement with the `NoTerminator` trait. The top-level `ModuleOp` is an +example of such operation which defined this trait and whose block body does not +have a terminator. + +Blocks in MLIR take a list of block arguments, notated in a function-like way. +Block arguments are bound to values specified by the semantics of individual +operations. Block arguments of the entry block of a region are also arguments to +the region and the values bound to these arguments are determined by the +semantics of the containing operation. 
Block arguments of other blocks are +determined by the semantics of terminator operations, e.g. Branches, which have +the block as a successor. In regions with +[control flow](#control-flow-and-ssacfg-regions), MLIR leverages this structure +to implicitly represent the passage of control-flow dependent values without the complex nuances of PHI nodes in traditional SSA representations. Note that values which are not control-flow dependent can be referenced directly and do not need to be passed through block arguments. @@ -389,7 +384,7 @@ br ^bb3(%a: i64) // Branch passes %a as the argument ^bb2: - %b = addi %a, %a : i64 + %b = arith.addi %a, %a : i64 br ^bb3(%b: i64) // Branch passes %b as the argument // ^bb3 receives an argument, named %c, from predecessors @@ -400,21 +395,20 @@ br ^bb4(%c, %a : i64, i64) ^bb4(%d : i64, %e : i64): - %0 = addi %d, %e : i64 + %0 = arith.addi %d, %e : i64 return %0 : i64 // Return is also a terminator. } ``` -**Context:** The "block argument" representation eliminates a number -of special cases from the IR compared to traditional "PHI nodes are -operations" SSA IRs (like LLVM). For example, the [parallel copy -semantics](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.524.5461&rep=rep1&type=pdf) -of SSA is immediately apparent, and function arguments are no longer a -special case: they become arguments to the entry block [[more -rationale](Rationale/Rationale.md/#block-arguments-vs-phi-nodes)]. Blocks -are also a fundamental concept that cannot be represented by -operations because values defined in an operation cannot be accessed -outside the operation. +**Context:** The "block argument" representation eliminates a number of special +cases from the IR compared to traditional "PHI nodes are operations" SSA IRs +(like LLVM). For example, the +[parallel copy semantics](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.524.5461&rep=rep1&type=pdf) +of SSA is immediately apparent, and function arguments are no longer a special +case: they become arguments to the entry block +[[more rationale](Rationale/Rationale.md/#block-arguments-vs-phi-nodes)]. Blocks +are also a fundamental concept that cannot be represented by operations because +values defined in an operation cannot be accessed outside the operation. ## Regions @@ -425,16 +419,15 @@ semantics of the regions it contains. MLIR currently defines two kinds of regions: [SSACFG regions](#control-flow-and-ssacfg-regions), which describe control flow between blocks, and [Graph regions](#graph-regions), which do not -require control flow between block. The kinds of regions within an operation -are described using the -[RegionKindInterface](Interfaces.md/#regionkindinterfaces). +require control flow between block. The kinds of regions within an operation are +described using the [RegionKindInterface](Interfaces.md/#regionkindinterfaces). -Regions do not have a name or an address, only the blocks contained in a -region do. Regions must be contained within operations and have no type or -attributes. The first block in the region is a special block called the 'entry -block'. The arguments to the entry block are also the arguments of the region -itself. The entry block cannot be listed as a successor of any other -block. The syntax for a region is as follows: +Regions do not have a name or an address, only the blocks contained in a region +do. Regions must be contained within operations and have no type or attributes. +The first block in the region is a special block called the 'entry block'. 
The +arguments to the entry block are also the arguments of the region itself. The +entry block cannot be listed as a successor of any other block. The syntax for a +region is as follows: ``` region ::= `{` block* `}` @@ -444,21 +437,20 @@ has additional semantic restrictions that other types of regions may not have. For example, in a function body, block terminators must either branch to a different block, or return from a function where the types of the `return` -arguments must match the result types of the function signature. Similarly, -the function arguments must match the types and count of the region arguments. -In general, operations with regions can define these correspondances -arbitrarily. +arguments must match the result types of the function signature. Similarly, the +function arguments must match the types and count of the region arguments. In +general, operations with regions can define these correspondances arbitrarily. ### Value Scoping Regions provide hierarchical encapsulation of programs: it is impossible to -reference, i.e. branch to, a block which is not in the same region as the -source of the reference, i.e. a terminator operation. Similarly, regions -provides a natural scoping for value visibility: values defined in a region -don't escape to the enclosing region, if any. By default, operations inside a -region can reference values defined outside of the region whenever it would -have been legal for operands of the enclosing operation to reference those -values, but this can be restricted using traits, such as +reference, i.e. branch to, a block which is not in the same region as the source +of the reference, i.e. a terminator operation. Similarly, regions provides a +natural scoping for value visibility: values defined in a region don't escape to +the enclosing region, if any. By default, operations inside a region can +reference values defined outside of the region whenever it would have been legal +for operands of the enclosing operation to reference those values, but this can +be restricted using traits, such as [OpTrait::IsolatedFromAbove](Traits.md/#isolatedfromabove), or a custom verifier. @@ -466,56 +458,54 @@ ```mlir "any_op"(%a) ({ // if %a is in-scope in the containing region... - // then %a is in-scope here too. + // then %a is in-scope here too. %new_value = "another_op"(%a) : (i64) -> (i64) }) : (i64) -> (i64) ``` -MLIR defines a generalized 'hierarchical dominance' concept that operates -across hierarchy and defines whether a value is 'in scope' and can be used by -a particular operation. Whether a value can be used by another operation in -the same region is defined by the kind of region. A value defined in a region -can be used by an operation which has a parent in the same region, if and only -if the parent could use the value. A value defined by an argument to a region -can always be used by any operation deeply contained in the region. A value -defined in a region can never be used outside of the region. +MLIR defines a generalized 'hierarchical dominance' concept that operates across +hierarchy and defines whether a value is 'in scope' and can be used by a +particular operation. Whether a value can be used by another operation in the +same region is defined by the kind of region. A value defined in a region can be +used by an operation which has a parent in the same region, if and only if the +parent could use the value. A value defined by an argument to a region can +always be used by any operation deeply contained in the region. 
A value defined +in a region can never be used outside of the region. ### Control Flow and SSACFG Regions In MLIR, control flow semantics of a region is indicated by -[RegionKind::SSACFG](Interfaces.md/#regionkindinterfaces). Informally, these -regions support semantics where operations in a region 'execute -sequentially'. Before an operation executes, its operands have well-defined -values. After an operation executes, the operands have the same values and -results also have well-defined values. After an operation executes, the next -operation in the block executes until the operation is the terminator operation -at the end of a block, in which case some other operation will execute. The -determination of the next instruction to execute is the 'passing of control -flow'. - -In general, when control flow is passed to an operation, MLIR does not -restrict when control flow enters or exits the regions contained in that -operation. However, when control flow enters a region, it always begins in the -first block of the region, called the *entry* block. Terminator operations -ending each block represent control flow by explicitly specifying the -successor blocks of the block. Control flow can only pass to one of the -specified successor blocks as in a `branch` operation, or back to the -containing operation as in a `return` operation. Terminator operations without -successors can only pass control back to the containing operation. Within -these restrictions, the particular semantics of terminator operations is -determined by the specific dialect operations involved. Blocks (other than the -entry block) that are not listed as a successor of a terminator operation are -defined to be unreachable and can be removed without affecting the semantics -of the containing operation. +[RegionKind::SSACFG](Interfaces.md/#regionkindinterfaces). Informally, these +regions support semantics where operations in a region 'execute sequentially'. +Before an operation executes, its operands have well-defined values. After an +operation executes, the operands have the same values and results also have +well-defined values. After an operation executes, the next operation in the +block executes until the operation is the terminator operation at the end of a +block, in which case some other operation will execute. The determination of the +next instruction to execute is the 'passing of control flow'. + +In general, when control flow is passed to an operation, MLIR does not restrict +when control flow enters or exits the regions contained in that operation. +However, when control flow enters a region, it always begins in the first block +of the region, called the *entry* block. Terminator operations ending each block +represent control flow by explicitly specifying the successor blocks of the +block. Control flow can only pass to one of the specified successor blocks as in +a `branch` operation, or back to the containing operation as in a `return` +operation. Terminator operations without successors can only pass control back +to the containing operation. Within these restrictions, the particular semantics +of terminator operations is determined by the specific dialect operations +involved. Blocks (other than the entry block) that are not listed as a successor +of a terminator operation are defined to be unreachable and can be removed +without affecting the semantics of the containing operation. 
Although control flow always enters a region through the entry block, control flow may exit a region through any block with an appropriate terminator. The standard dialect leverages this capability to define operations with Single-Entry-Multiple-Exit (SEME) regions, possibly flowing through different -blocks in the region and exiting through any block with a `return` -operation. This behavior is similar to that of a function body in most -programming languages. In addition, control flow may also not reach the end of -a block or region, for example if a function call does not return. +blocks in the region and exiting through any block with a `return` operation. +This behavior is similar to that of a function body in most programming +languages. In addition, control flow may also not reach the end of a block or +region, for example if a function call does not return. Example: @@ -548,14 +538,14 @@ An operation containing multiple regions also completely determines the semantics of those regions. In particular, when control flow is passed to an operation, it may transfer control flow to any contained region. When control -flow exits a region and is returned to the containing operation, the -containing operation may pass control flow to any region in the same -operation. An operation may also pass control flow to multiple contained -regions concurrently. An operation may also pass control flow into regions -that were specified in other operations, in particular those that defined the -values or symbols the given operation uses as in a call operation. This -passage of control is generally independent of passage of control flow through -the basic blocks of the containing region. +flow exits a region and is returned to the containing operation, the containing +operation may pass control flow to any region in the same operation. An +operation may also pass control flow to multiple contained regions concurrently. +An operation may also pass control flow into regions that were specified in +other operations, in particular those that defined the values or symbols the +given operation uses as in a call operation. This passage of control is +generally independent of passage of control flow through the basic blocks of the +containing region. #### Closure @@ -579,19 +569,19 @@ completely determined by its containing operation. Graph regions may only contain a single basic block (the entry block). -**Rationale:** Currently graph regions are arbitrarily limited to a single -basic block, although there is no particular semantic reason for this -limitation. This limitation has been added to make it easier to stabilize the -pass infrastructure and commonly used passes for processing graph regions to -properly handle feedback loops. Multi-block regions may be allowed in the -future if use cases that require it arise. +**Rationale:** Currently graph regions are arbitrarily limited to a single basic +block, although there is no particular semantic reason for this limitation. This +limitation has been added to make it easier to stabilize the pass infrastructure +and commonly used passes for processing graph regions to properly handle +feedback loops. Multi-block regions may be allowed in the future if use cases +that require it arise. In graph regions, MLIR operations naturally represent nodes, while each MLIR value represents a multi-edge connecting a single source node and multiple -destination nodes. 
All values defined in the region as results of operations -are in scope within the region and can be accessed by any other operation in -the region. In graph regions, the order of operations within a block and the -order of blocks in a region is not semantically meaningful and non-terminator +destination nodes. All values defined in the region as results of operations are +in scope within the region and can be accessed by any other operation in the +region. In graph regions, the order of operations within a block and the order +of blocks in a region is not semantically meaningful and non-terminator operations may be freely reordered, for instance, by canonicalization. Other kinds of graphs, such as graphs with multiple source nodes and multiple destination nodes, can also be represented by representing graph edges as MLIR @@ -604,7 +594,7 @@ "test.graph_region"() ({ // A Graph region %1 = "op1"(%1, %3) : (i32, i32) -> (i32) // OK: %1, %3 allowed here %2 = "test.ssacfg_region"() ({ - %5 = "op2"(%1, %2, %3, %4) : (i32, i32, i32, i32) -> (i32) // OK: %1, %2, %3, %4 all defined in the containing region + %5 = "op2"(%1, %2, %3, %4) : (i32, i32, i32, i32) -> (i32) // OK: %1, %2, %3, %4 all defined in the containing region }) : () -> (i32) %3 = "op2"(%1, %4) : (i32, i32) -> (i32) // OK: %4 allowed here %4 = "op3"(%1) : (i32) -> (i32) @@ -754,16 +744,17 @@ semantics. The attribute entries are considered to be of two different kinds based on whether their dictionary key has a dialect prefix: -- *inherent attributes* are inherent to the definition of an operation's - semantics. The operation itself is expected to verify the consistency of these - attributes. An example is the `predicate` attribute of the `std.cmpi` op. - These attributes must have names that do not start with a dialect prefix. - -- *discardable attributes* have semantics defined externally to the operation - itself, but must be compatible with the operations's semantics. These - attributes must have names that start with a dialect prefix. The dialect - indicated by the dialect prefix is expected to verify these attributes. An - example is the `gpu.container_module` attribute. +- *inherent attributes* are inherent to the definition of an operation's + semantics. The operation itself is expected to verify the consistency of + these attributes. An example is the `predicate` attribute of the + `arith.cmpi` op. These attributes must have names that do not start with a + dialect prefix. + +- *discardable attributes* have semantics defined externally to the operation + itself, but must be compatible with the operations's semantics. These + attributes must have names that start with a dialect prefix. The dialect + indicated by the dialect prefix is expected to verify these attributes. An + example is the `gpu.container_module` attribute. Note that attribute values are allowed to themselves be dictionary attributes, but only the top-level dictionary attribute attached to the operation is subject diff --git a/mlir/docs/Rationale/MLIRForGraphAlgorithms.md b/mlir/docs/Rationale/MLIRForGraphAlgorithms.md --- a/mlir/docs/Rationale/MLIRForGraphAlgorithms.md +++ b/mlir/docs/Rationale/MLIRForGraphAlgorithms.md @@ -8,7 +8,7 @@ fixed in place? This document explains that adoption of MLIR to solve graph based problems -_isn't_ a revolutionary change: it is an incremental series of steps which build +*isn't* a revolutionary change: it is an incremental series of steps which build on each other, each of which delivers local value. 
This document also addresses some points of confusion that keep coming up. @@ -156,7 +156,7 @@ ```mlir // RUN: mlir-opt %s -canonicalize | FileCheck %s func @test_subi_zero_cfg(%arg0: i32) -> i32 { - %y = subi %arg0, %arg0 : i32 + %y = arith.subi %arg0, %arg0 : i32 return %y: i32 } // CHECK-LABEL: func @test_subi_zero_cfg(%arg0: i32) @@ -210,13 +210,13 @@ ```mlir // RUN: mlir-opt %s -memref-dependence-check -verify-diagnostics func @different_memrefs() { - %m.a = alloc() : memref<100xf32> - %m.b = alloc() : memref<100xf32> - %c0 = constant 0 : index - %c1 = constant 1.0 : f32 - store %c1, %m.a[%c0] : memref<100xf32> + %m.a = memref.alloc() : memref<100xf32> + %m.b = memref.alloc() : memref<100xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1.0 : f32 + memref.store %c1, %m.a[%c0] : memref<100xf32> // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} - %v0 = load %m.b[%c0] : memref<100xf32> + %v0 = memref.load %m.b[%c0] : memref<100xf32> return } ``` @@ -238,8 +238,8 @@ capture this (e.g. serialize it to proto), passes have to recompute it on demand with ShapeRefiner. -The [MLIR Tensor Type](../Dialects/Builtin.md/#rankedtensortype) directly captures shape -information, so you can have things like: +The [MLIR Tensor Type](../Dialects/Builtin.md/#rankedtensortype) directly +captures shape information, so you can have things like: ```mlir %x = tf.Add %x, %y : tensor<128 x 8 x ? x f32> @@ -254,11 +254,11 @@ ### Unified Graph Rewriting Infrastructure This is still a work in progress, but we have sightlines towards a -[general rewriting infrastructure](RationaleGenericDAGRewriter.md) for transforming DAG -tiles into other DAG tiles, using a declarative pattern format. DAG to DAG -rewriting is a generalized solution for many common compiler optimizations, -lowerings, and other rewrites and having an IR enables us to invest in building -a single high-quality implementation. +[general rewriting infrastructure](RationaleGenericDAGRewriter.md) for +transforming DAG tiles into other DAG tiles, using a declarative pattern format. +DAG to DAG rewriting is a generalized solution for many common compiler +optimizations, lowerings, and other rewrites and having an IR enables us to +invest in building a single high-quality implementation. Declarative pattern rules are preferable to imperative C++ code for a number of reasons: they are more compact, easier to reason about, can have checkers diff --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md --- a/mlir/docs/Rationale/Rationale.md +++ b/mlir/docs/Rationale/Rationale.md @@ -58,12 +58,12 @@ Maps, sets, and relations with affine constraints are the core structures underlying a polyhedral representation of high-dimensional loop nests and -multidimensional arrays. These structures are represented as textual -expressions in a form close to their mathematical form. These structures are -used to capture loop nests, tensor data structures, and how they are reordered -and mapped for a target architecture. All structured or "conforming" loops are -captured as part of the polyhedral information, and so are tensor variables, -their layouts, and subscripted accesses to these tensors in memory. +multidimensional arrays. These structures are represented as textual expressions +in a form close to their mathematical form. These structures are used to capture +loop nests, tensor data structures, and how they are reordered and mapped for a +target architecture. 
All structured or "conforming" loops are captured as part +of the polyhedral information, and so are tensor variables, their layouts, and +subscripted accesses to these tensors in memory. The information captured in the IR allows a compact expression of all loop transformations, data remappings, explicit copying necessary for explicitly @@ -113,17 +113,19 @@ ability to index into the same memref in other ways (something which C arrays allow for example). Furthermore, for the affine constructs, the compiler can follow use-def chains (e.g. through -[affine.apply operations](../Dialects/Affine.md/#affineapply-affineapplyop)) or through -the map attributes of [affine operations](../Dialects/Affine.md/#operations)) to -precisely analyze references at compile-time using polyhedral techniques. This -is possible because of the [restrictions on dimensions and symbols](../Dialects/Affine.md/#restrictions-on-dimensions-and-symbols). +[affine.apply operations](../Dialects/Affine.md/#affineapply-affineapplyop)) or +through the map attributes of +[affine operations](../Dialects/Affine.md/#operations)) to precisely analyze +references at compile-time using polyhedral techniques. This is possible because +of the +[restrictions on dimensions and symbols](../Dialects/Affine.md/#restrictions-on-dimensions-and-symbols). A scalar of element-type (a primitive type or a vector type) that is stored in memory is modeled as a 0-d memref. This is also necessary for scalars that are live out of for loops and if conditionals in a function, for which we don't yet have an SSA representation -- -[an extension](#affineif-and-affinefor-extensions-for-escaping-scalars) to allow that is -described later in this doc. +[an extension](#affineif-and-affinefor-extensions-for-escaping-scalars) to allow +that is described later in this doc. ### Symbols and types @@ -136,7 +138,7 @@ ```mlir func foo(...) { - %A = alloc <8x?xf32, #lmap> (%N) + %A = memref.alloc <8x?xf32, #lmap> (%N) ... call bar(%A) : (memref<8x?xf32, #lmap>) } @@ -145,7 +147,7 @@ // Type of %A indicates that %A has dynamic shape with 8 rows // and unknown number of columns. The number of columns is queried // dynamically using dim instruction. - %N = dim %A, 1 : memref<8x?xf32, #lmap> + %N = memref.dim %A, 1 : memref<8x?xf32, #lmap> affine.for %i = 0 to 8 { affine.for %j = 0 to %N { @@ -167,9 +169,9 @@ ### Block Arguments vs PHI nodes -MLIR Regions represent SSA using "[block arguments](../LangRef.md/#blocks)" rather -than [PHI instructions](http://llvm.org/docs/LangRef.html#i-phi) used in LLVM. -This choice is representationally identical (the same constructs can be +MLIR Regions represent SSA using "[block arguments](../LangRef.md/#blocks)" +rather than [PHI instructions](http://llvm.org/docs/LangRef.html#i-phi) used in +LLVM. This choice is representationally identical (the same constructs can be represented in either form) but block arguments have several advantages: 1. LLVM PHI nodes always have to be kept at the top of a block, and @@ -220,10 +222,10 @@ Data layout information such as the bit width or the alignment of types may be target and ABI-specific and thus should be configurable rather than imposed by the compiler. Especially, the layout of compound or `index` types may vary. MLIR -specifies default bit widths for certain primitive _types_, in particular for +specifies default bit widths for certain primitive *types*, in particular for integers and floats. It is equal to the number that appears in the type definition, e.g. 
the bit width of `i32` is `32`, so is the bit width of `f32`. -The bit width is not _necessarily_ related to the amount of memory (in bytes) or +The bit width is not *necessarily* related to the amount of memory (in bytes) or the register size (in bits) that is necessary to store the value of the given type. For example, `vector<3xi57>` is likely to be lowered to a vector of four 64-bit integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes, @@ -250,8 +252,9 @@ For the standard dialect, the choice is to have signless integer types. An integer value does not have an intrinsic sign, and it's up to the specific op -for interpretation. For example, ops like `addi` and `muli` do two's complement -arithmetic, but some other operations get a sign, e.g. `divis` vs `diviu`. +for interpretation. For example, ops like `arith.addi` and `arith.muli` do two's +complement arithmetic, but some other operations get a sign, e.g. `arith.divsi` +vs `arith.divui`. LLVM uses the [same design](http://llvm.org/docs/LangRef.html#integer-type), which was introduced in a revamp rolled out @@ -279,11 +282,11 @@ ### Splitting floating point vs integer operations -The MLIR "standard" operation set splits many integer and floating point -operations into different categories, for example `addf` vs `addi` and `cmpf` vs -`cmpi` +The MLIR "Arithmetic" dialect splits many integer and floating point operations +into different categories, for example `arith.addf` vs `arith.addi` and +`arith.cmpf` vs `arith.cmpi` ([following the design of LLVM](http://llvm.org/docs/LangRef.html#binary-operations)). -These instructions _are_ polymorphic on the number of elements in the type +These instructions *are* polymorphic on the number of elements in the type though, for example `addf` is used with scalar floats, vectors of floats, and tensors of floats (LLVM does the same thing with its scalar/vector types). @@ -308,12 +311,12 @@ ### Specifying sign in integer comparison operations -Since integers are [signless](#integer-signedness-semantics), it is necessary to define the -sign for integer comparison operations. This sign indicates how to treat the -foremost bit of the integer: as sign bit or as most significant bit. For -example, comparing two `i4` values `0b1000` and `0b0010` yields different +Since integers are [signless](#integer-signedness-semantics), it is necessary to +define the sign for integer comparison operations. This sign indicates how to +treat the foremost bit of the integer: as sign bit or as most significant bit. +For example, comparing two `i4` values `0b1000` and `0b0010` yields different results for unsigned (`8 > 3`) and signed (`-8 < 3`) interpretations. This -difference is only significant for _order_ comparisons, but not for _equality_ +difference is only significant for *order* comparisons, but not for *equality* comparisons. Indeed, for the latter all bits must have the same value independently of the sign. Since both arguments have exactly the same bit width and cannot be padded by this operation, it is impossible to compare two values @@ -491,10 +494,10 @@ ### Tuple types The MLIR type system provides first class support for defining -[tuple types](../Dialects/Builtin/#tupletype). This is due to the fact that `Tuple` -represents a universal concept that is likely to, and has already begun to, -present itself in many different dialects. Though this type is first class in -the type system, it merely serves to provide a common mechanism in which to +[tuple types](../Dialects/Builtin/#tupletype). 
This is due to the fact that +`Tuple` represents a universal concept that is likely to, and has already begun +to, present itself in many different dialects. Though this type is first class +in the type system, it merely serves to provide a common mechanism in which to represent this concept in MLIR. As such, MLIR provides no standard operations for interfacing with `tuple` types. It is up to dialect authors to provide operations, e.g. extract_tuple_element, to interpret and manipulate them. When @@ -547,7 +550,7 @@ ```mlir func @search(%A: memref, %S: , %key : i32) { - %ni = dim %A, 0 : memref + %ni = memref.dim %A, 0 : memref // This loop can be parallelized affine.for %i = 0 to %ni { call @search_body (%A, %S, %key, %i) : (memref, memref, i32, i32) @@ -556,16 +559,16 @@ } func @search_body(%A: memref, %S: memref, %key: i32, %i : i32) { - %nj = dim %A, 1 : memref + %nj = memref.dim %A, 1 : memref br ^bb1(0) ^bb1(%j: i32) - %p1 = cmpi "lt", %j, %nj : i32 + %p1 = arith.cmpi "lt", %j, %nj : i32 cond_br %p1, ^bb2, ^bb5 ^bb2: %v = affine.load %A[%i, %j] : memref - %p2 = cmpi "eq", %v, %key : i32 + %p2 = arith.cmpi "eq", %v, %key : i32 cond_br %p2, ^bb3(%j), ^bb4 ^bb3(%j: i32) @@ -573,7 +576,7 @@ br ^bb5 ^bb4: - %jinc = addi %j, 1 : i32 + %jinc = arith.addi %j, 1 : i32 br ^bb1(%jinc) ^bb5: @@ -728,10 +731,10 @@ explicitly propagate the schedule into domains and model all the cleanup code. An example and more detail on the schedule tree form is in the next section. -1. Having two different forms of "affine regions": an affine loop tree form - and a polyhedral schedule tree form. In the latter, ops could carry - attributes capturing domain, scheduling, and other polyhedral code - generation options with IntegerSet, AffineMap, and other attributes. +1. Having two different forms of "affine regions": an affine loop tree form and + a polyhedral schedule tree form. In the latter, ops could carry attributes + capturing domain, scheduling, and other polyhedral code generation options + with IntegerSet, AffineMap, and other attributes. #### Schedule Tree Representation for Affine Regions @@ -788,12 +791,11 @@ ### Affine Relations -The current MLIR spec includes affine maps and integer sets, but not -affine relations. Affine relations are a natural way to model read and -write access information, which can be very useful to capture the -behavior of external library calls where no implementation is -available, high-performance vendor libraries, or user-provided / -user-tuned routines. +The current MLIR spec includes affine maps and integer sets, but not affine +relations. Affine relations are a natural way to model read and write access +information, which can be very useful to capture the behavior of external +library calls where no implementation is available, high-performance vendor +libraries, or user-provided / user-tuned routines. An affine relation is a relation between input and output dimension identifiers while being symbolic on a list of symbolic identifiers and with affine @@ -844,7 +846,7 @@ bb0 (%0, %1: memref<128xf32>, i64): %val = affine.load %A [%pos] %val = affine.load %A [%pos + 1] - %p = mulf %val, %val : f32 + %p = arith.mulf %val, %val : f32 return %p : f32 } ``` diff --git a/mlir/docs/SPIRVToLLVMDialectConversion.md b/mlir/docs/SPIRVToLLVMDialectConversion.md --- a/mlir/docs/SPIRVToLLVMDialectConversion.md +++ b/mlir/docs/SPIRVToLLVMDialectConversion.md @@ -58,21 +58,21 @@ strides (based on [`VulkanLayoutUtils`][VulkanLayoutUtils]) are supported. They are also mapped to LLVM array. 
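For illustration, a concrete instance of this array mapping (the element type and count below are chosen arbitrarily for the sketch, not taken from the patch) would look like:

```mlir
// Fixed-size array: the count and element type carry over directly.
!spv.array<16 x f32>  =>  !llvm.array<16 x f32>
// Runtime array: the unknown length is modelled as a zero-element LLVM array.
!spv.rtarray<f32>     =>  !llvm.array<0 x f32>
```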
-SPIR-V Dialect | LLVM Dialect -:-----------------------------------: | :-----------------------------------: -`!spv.array< x >`| `!llvm.array< x >` -`!spv.rtarray< >` | `!llvm.array<0 x >` +SPIR-V Dialect | LLVM Dialect +:------------------------------------: | :-------------------------------------: +`!spv.array< x >` | `!llvm.array< x >` +`!spv.rtarray< >` | `!llvm.array<0 x >` ### Struct types Members of SPIR-V struct types may have decorations and offset information. Currently, there is **no** support of member decorations conversion for structs. -For more information see section on [Decorations](#Decorations-conversion). +For more information see section on [Decorations](#Decorations-conversion). Usually we expect that each struct member has a natural size and alignment. -However, there are cases (*e.g.* in graphics) where one would place struct -members explicitly at particular offsets. This case is **not** supported -at the moment. Hence, we adhere to the following mapping: +However, there are cases (*e.g.* in graphics) where one would place struct +members explicitly at particular offsets. This case is **not** supported at the +moment. Hence, we adhere to the following mapping: * Structs with no offset are modelled as LLVM packed structures. @@ -86,14 +86,11 @@ a design would require index recalculation in the conversion of ops that involve memory addressing. -Examples of SPIR-V struct conversion are: -```mlir -!spv.struct => !llvm.struct -!spv.struct => !llvm.struct<(i8, i32)> +Examples of SPIR-V struct conversion are: ```mlir !spv.struct => +!llvm.struct !spv.struct => !llvm.struct<(i8, +i32)> -// error -!spv.struct -``` +// error !spv.struct ``` ### Not implemented types @@ -104,10 +101,10 @@ This section describes how SPIR-V Dialect operations are converted to LLVM Dialect. It lists already working conversion patterns, as well as those that are -an ongoing work. +an ongoing work. There are also multiple ops for which there is no clear mapping in LLVM. -Conversion for those have to be discussed within the community on the +Conversion for those have to be discussed within the community on the case-by-case basis. ### Arithmetic ops @@ -115,21 +112,21 @@ SPIR-V arithmetic ops mostly have a direct equivalent in LLVM Dialect. Such exceptions as `spv.SMod` and `spv.FMod` are rare. 
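As a rough sketch of what one of these direct mappings looks like in practice (the operand names and the `f32` type are illustrative assumptions, not taken from the conversion tests), a floating-point addition converts one-to-one:

```mlir
// SPIR-V dialect
%sum = spv.FAdd %lhs, %rhs : f32

// After conversion to the LLVM dialect
%sum = llvm.fadd %lhs, %rhs : f32
```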
-SPIR-V Dialect op | LLVM Dialect op -:-----------------------------------: | :-----------------------------------: -`spv.FAdd` | `llvm.fadd` -`spv.FDiv` | `llvm.fdiv` -`spv.FNegate` | `llvm.fneg` -`spv.FMul` | `llvm.fmul` -`spv.FRem` | `llvm.frem` -`spv.FSub` | `llvm.fsub` -`spv.IAdd` | `llvm.add` -`spv.IMul` | `llvm.mul` -`spv.ISub` | `llvm.sub` -`spv.SDiv` | `llvm.sdiv` -`spv.SRem` | `llvm.srem` -`spv.UDiv` | `llvm.udiv` -`spv.UMod` | `llvm.urem` +SPIR-V Dialect op | LLVM Dialect op +:---------------: | :-------------: +`spv.FAdd` | `llvm.fadd` +`spv.FDiv` | `llvm.fdiv` +`spv.FNegate` | `llvm.fneg` +`spv.FMul` | `llvm.fmul` +`spv.FRem` | `llvm.frem` +`spv.FSub` | `llvm.fsub` +`spv.IAdd` | `llvm.add` +`spv.IMul` | `llvm.mul` +`spv.ISub` | `llvm.sub` +`spv.SDiv` | `llvm.sdiv` +`spv.SRem` | `llvm.srem` +`spv.UDiv` | `llvm.udiv` +`spv.UMod` | `llvm.urem` ### Bitwise ops @@ -141,18 +138,18 @@ As with arithmetic ops, most of bitwise ops have a semantically equivalent op in LLVM: -SPIR-V Dialect op | LLVM Dialect op -:-----------------------------------: | :-----------------------------------: -`spv.BitwiseAnd` | `llvm.and` -`spv.BitwiseOr` | `llvm.or` -`spv.BitwiseXor` | `llvm.xor` +SPIR-V Dialect op | LLVM Dialect op +:---------------: | :-------------: +`spv.BitwiseAnd` | `llvm.and` +`spv.BitwiseOr` | `llvm.or` +`spv.BitwiseXor` | `llvm.xor` Also, some of bitwise ops can be modelled with LLVM intrinsics: -SPIR-V Dialect op | LLVM Dialect intrinsic -:-----------------------------------: | :-----------------------------------: -`spv.BitCount` | `llvm.intr.ctpop` -`spv.BitReverse` | `llvm.intr.bitreverse` +SPIR-V Dialect op | LLVM Dialect intrinsic +:---------------: | :--------------------: +`spv.BitCount` | `llvm.intr.ctpop` +`spv.BitReverse` | `llvm.intr.bitreverse` #### `spv.Not` @@ -170,9 +167,8 @@ outline the general design of conversion patterns for this ops, and then describe each of them. -All of these ops take `base`, `offset` and `count` (`insert` for -`spv.BitFieldInsert`) as arguments. There are two important things -to note: +All of these ops take `base`, `offset` and `count` (`insert` for +`spv.BitFieldInsert`) as arguments. There are two important things to note: * `offset` and `count` are always scalar. This means that we can have the following case: @@ -220,10 +216,9 @@ ##### `spv.BitFieldInsert` This operation is implemented as a series of LLVM Dialect operations. First step -would be to create a mask with bits set outside -[`offset`, `offset` + `count` - 1]. Then, unchanged bits are extracted from -`base` that are outside of [`offset`, `offset` + `count` - 1]. The result is -`or`ed with shifted `insert`. +would be to create a mask with bits set outside [`offset`, `offset` + `count` - +1]. Then, unchanged bits are extracted from `base` that are outside of +[`offset`, `offset` + `count` - 1]. The result is `or`ed with shifted `insert`. ```mlir // Create mask @@ -284,73 +279,79 @@ #### Direct conversions -SPIR-V Dialect op | LLVM Dialect op -:-----------------------------------: | :-----------------------------------: -`spv.ConvertFToS` | `llvm.fptosi` -`spv.ConvertFToU` | `llvm.fptoui` -`spv.ConvertSToF` | `llvm.sitofp` -`spv.ConvertUToF` | `llvm.uitofp` +SPIR-V Dialect op | LLVM Dialect op +:---------------: | :-------------: +`spv.ConvertFToS` | `llvm.fptosi` +`spv.ConvertFToU` | `llvm.fptoui` +`spv.ConvertSToF` | `llvm.sitofp` +`spv.ConvertUToF` | `llvm.uitofp` #### spv.Bitcast + This operation has a direct counterpart in LLVM: `llvm.bitcast`. 
It is treated separately since it also supports pointer to pointer bit pattern-preserving type conversion, apart from regular scalar or vector of numerical type. #### Special cases + Special cases include `spv.FConvert`, `spv.SConvert` and `spv.UConvert`. These operations are either a truncate or extend. Let's denote the operand component width as A, and result component width as R. Then, the following mappings are used: -##### `spv.FConvert` -Case | LLVM Dialect op -:-------------: | :-----------------------------------: -A < R | `llvm.fpext` -A > R | `llvm.fptrunc` +##### `spv.FConvert` + +Case | LLVM Dialect op +:---: | :-------------: +A < R | `llvm.fpext` +A > R | `llvm.fptrunc` -##### `spv.SConvert` -Case | LLVM Dialect op -:-------------: | :-----------------------------------: -A < R | `llvm.sext` -A > R | `llvm.trunc` +##### `spv.SConvert` -##### `spv.UConvert` -Case | LLVM Dialect op -:-------------: | :-----------------------------------: -A < R | `llvm.zext` -A > R | `llvm.trunc` +Case | LLVM Dialect op +:---: | :-------------: +A < R | `llvm.sext` +A > R | `llvm.trunc` + +##### `spv.UConvert` + +Case | LLVM Dialect op +:---: | :-------------: +A < R | `llvm.zext` +A > R | `llvm.trunc` The case when A = R is not possible, based on SPIR-V Dialect specification: + > The component width cannot equal the component width in Result Type. ### Comparison ops SPIR-V comparison ops are mapped to LLVM `icmp` and `fcmp` operations. -SPIR-V Dialect op | LLVM Dialect op -:-----------------------------------: | :-----------------------------------: -`spv.IEqual` | `llvm.icmp "eq"` -`spv.INotEqual` | `llvm.icmp "ne"` -`spv.FOrdEqual` | `llvm.fcmp "oeq"` -`spv.FOrdGreaterThan` | `llvm.fcmp "ogt"` -`spv.FOrdGreaterThanEqual` | `llvm.fcmp "oge"` -`spv.FOrdLessThan` | `llvm.fcmp "olt"` -`spv.FOrdLessThanEqual` | `llvm.fcmp "ole"` -`spv.FOrdNotEqual` | `llvm.fcmp "one"` -`spv.FUnordEqual` | `llvm.fcmp "ueq"` -`spv.FUnordGreaterThan` | `llvm.fcmp "ugt"` -`spv.FUnordGreaterThanEqual` | `llvm.fcmp "uge"` -`spv.FUnordLessThan` | `llvm.fcmp "ult"` -`spv.FUnordLessThanEqual` | `llvm.fcmp "ule"` -`spv.FUnordNotEqual` | `llvm.fcmp "une"` -`spv.SGreaterThan` | `llvm.icmp "sgt"` -`spv.SGreaterThanEqual` | `llvm.icmp "sge"` -`spv.SLessThan` | `llvm.icmp "slt"` -`spv.SLessThanEqual` | `llvm.icmp "sle"` -`spv.UGreaterThan` | `llvm.icmp "ugt"` -`spv.UGreaterThanEqual` | `llvm.icmp "uge"` -`spv.ULessThan` | `llvm.icmp "ult"` -`spv.ULessThanEqual` | `llvm.icmp "ule"` +SPIR-V Dialect op | LLVM Dialect op +:--------------------------: | :---------------: +`spv.IEqual` | `llvm.icmp "eq"` +`spv.INotEqual` | `llvm.icmp "ne"` +`spv.FOrdEqual` | `llvm.fcmp "oeq"` +`spv.FOrdGreaterThan` | `llvm.fcmp "ogt"` +`spv.FOrdGreaterThanEqual` | `llvm.fcmp "oge"` +`spv.FOrdLessThan` | `llvm.fcmp "olt"` +`spv.FOrdLessThanEqual` | `llvm.fcmp "ole"` +`spv.FOrdNotEqual` | `llvm.fcmp "one"` +`spv.FUnordEqual` | `llvm.fcmp "ueq"` +`spv.FUnordGreaterThan` | `llvm.fcmp "ugt"` +`spv.FUnordGreaterThanEqual` | `llvm.fcmp "uge"` +`spv.FUnordLessThan` | `llvm.fcmp "ult"` +`spv.FUnordLessThanEqual` | `llvm.fcmp "ule"` +`spv.FUnordNotEqual` | `llvm.fcmp "une"` +`spv.SGreaterThan` | `llvm.icmp "sgt"` +`spv.SGreaterThanEqual` | `llvm.icmp "sge"` +`spv.SLessThan` | `llvm.icmp "slt"` +`spv.SLessThanEqual` | `llvm.icmp "sle"` +`spv.UGreaterThan` | `llvm.icmp "ugt"` +`spv.UGreaterThanEqual` | `llvm.icmp "uge"` +`spv.ULessThan` | `llvm.icmp "ult"` +`spv.ULessThanEqual` | `llvm.icmp "ule"` ### Composite ops @@ -359,12 +360,12 @@ composite object is a 
vector, and when the composite object is of a non-vector type (*i.e.* struct, array or runtime array). -Composite type | SPIR-V Dialect op | LLVM Dialect op -:-------------: | :--------------------: | :--------------------: -vector | `spv.CompositeExtract` | `llvm.extractelement` -vector | `spv.CompositeInsert` | `llvm.insertelement` -non-vector | `spv.CompositeExtract` | `llvm.extractvalue` -non-vector | `spv.CompositeInsert` | `llvm.insertvalue` +Composite type | SPIR-V Dialect op | LLVM Dialect op +:------------: | :--------------------: | :-------------------: +vector | `spv.CompositeExtract` | `llvm.extractelement` +vector | `spv.CompositeInsert` | `llvm.insertelement` +non-vector | `spv.CompositeExtract` | `llvm.extractvalue` +non-vector | `spv.CompositeInsert` | `llvm.insertvalue` ### `spv.EntryPoint` and `spv.ExecutionMode` @@ -381,7 +382,7 @@ struct global variable that stores the execution mode id and any variables associated with it. In C, the struct has the structure shown below. - ```C + ```c // No values are associated // There are values that are associated // with this entry point. // with this entry point. struct { struct { @@ -406,12 +407,12 @@ they operate on `i1` or vector of `i1` values. The following mapping is used to emulate SPIR-V ops behaviour: -SPIR-V Dialect op | LLVM Dialect op -:-----------------------------------: | :-----------------------------------: -`spv.LogicalAnd` | `llvm.and` -`spv.LogicalOr` | `llvm.or` -`spv.LogicalEqual` | `llvm.icmp "eq"` -`spv.LogicalNotEqual` | `llvm.icmp "ne"` +SPIR-V Dialect op | LLVM Dialect op +:-------------------: | :--------------: +`spv.LogicalAnd` | `llvm.and` +`spv.LogicalOr` | `llvm.or` +`spv.LogicalEqual` | `llvm.icmp "eq"` +`spv.LogicalNotEqual` | `llvm.icmp "ne"` `spv.LogicalNot` has the same conversion pattern as bitwise `spv.Not`. It is modelled with `xor` operation with a mask with all bits set. @@ -468,13 +469,13 @@ #### `spv.GlobalVariable` and `spv.mlir.addressof` -`spv.GlobalVariable` is modelled with `llvm.mlir.global` op. However, there -is a difference that has to be pointed out. +`spv.GlobalVariable` is modelled with `llvm.mlir.global` op. However, there is a +difference that has to be pointed out. In SPIR-V dialect, the global variable returns a pointer, whereas in LLVM dialect the global holds an actual value. This difference is handled by -`spv.mlir.addressof` and `llvm.mlir.addressof` ops that both return a pointer and -are used to reference the global. +`spv.mlir.addressof` and `llvm.mlir.addressof` ops that both return a pointer +and are used to reference the global. ```mlir // Original SPIR-V module @@ -496,9 +497,9 @@ } ``` -The SPIR-V to LLVM conversion does not involve modelling of workgroups. -Hence, we say that only current invocation is in conversion's scope. This means -that global variables with pointers of `Input`, `Output`, and `Private` storage +The SPIR-V to LLVM conversion does not involve modelling of workgroups. Hence, +we say that only current invocation is in conversion's scope. This means that +global variables with pointers of `Input`, `Output`, and `Private` storage classes are supported. Also, `StorageBuffer` storage class is allowed for executing [`mlir-spirv-cpu-runner`](#mlir-spirv-cpu-runner). @@ -510,8 +511,8 @@ storage class and `External` for other storage classes, based on SPIR-V spec: > By default, functions and global variables are private to a module and cannot -be accessed by other modules. 
However, a module may be written to export or -import functions and global (module scope) variables. +> be accessed by other modules. However, a module may be written to export or +> import functions and global (module scope) variables. If the global variable's pointer has `Input` storage class, then a `constant` flag is added to LLVM op: @@ -554,10 +555,10 @@ converted directly to LLVM dialect. Their conversion is addressed in this section. -SPIR-V Dialect op | LLVM Dialect op -:-----------------------------------: | :-----------------------------------: -`spv.Select` | `llvm.select` -`spv.Undef` | `llvm.mlir.undef` +SPIR-V Dialect op | LLVM Dialect op +:---------------: | :---------------: +`spv.Select` | `llvm.select` +`spv.Undef` | `llvm.mlir.undef` ### Shift ops @@ -665,10 +666,10 @@ ### `spv.mlir.selection` and `spv.mlir.loop` -Control flow within `spv.mlir.selection` and `spv.mlir.loop` is lowered directly to LLVM -via branch ops. The conversion can only be applied to selection or loop with all -blocks being reachable. Moreover, selection and loop control attributes (such as -`Flatten` or `Unroll`) are not supported at the moment. +Control flow within `spv.mlir.selection` and `spv.mlir.loop` is lowered directly +to LLVM via branch ops. The conversion can only be applied to selection or loop +with all blocks being reachable. Moreover, selection and loop control attributes +(such as `Flatten` or `Unroll`) are not supported at the moment. ```mlir // Conversion of selection @@ -727,20 +728,20 @@ ### Direct conversions -SPIR-V Dialect op | LLVM Dialect op -:-----------------------------------: | :-----------------------------------: -`spv.GLSL.Ceil` | `llvm.intr.ceil` -`spv.GLSL.Cos` | `llvm.intr.cos` -`spv.GLSL.Exp` | `llvm.intr.exp` -`spv.GLSL.FAbs` | `llvm.intr.fabs` -`spv.GLSL.Floor` | `llvm.intr.floor` -`spv.GLSL.FMax` | `llvm.intr.maxnum` -`spv.GLSL.FMin` | `llvm.intr.minnum` -`spv.GLSL.Log` | `llvm.intr.log` -`spv.GLSL.Sin` | `llvm.intr.sin` -`spv.GLSL.Sqrt` | `llvm.intr.sqrt` -`spv.GLSL.SMax` | `llvm.intr.smax` -`spv.GLSL.SMin` | `llvm.intr.smin` +SPIR-V Dialect op | LLVM Dialect op +:---------------: | :----------------: +`spv.GLSL.Ceil` | `llvm.intr.ceil` +`spv.GLSL.Cos` | `llvm.intr.cos` +`spv.GLSL.Exp` | `llvm.intr.exp` +`spv.GLSL.FAbs` | `llvm.intr.fabs` +`spv.GLSL.Floor` | `llvm.intr.floor` +`spv.GLSL.FMax` | `llvm.intr.maxnum` +`spv.GLSL.FMin` | `llvm.intr.minnum` +`spv.GLSL.Log` | `llvm.intr.log` +`spv.GLSL.Sin` | `llvm.intr.sin` +`spv.GLSL.Sqrt` | `llvm.intr.sqrt` +`spv.GLSL.SMax` | `llvm.intr.smax` +`spv.GLSL.SMin` | `llvm.intr.smin` ### Special cases @@ -760,7 +761,8 @@ %res = fdiv %sin, %cos : f32 ``` -`spv.Tanh` is modelled using the equality `tanh(x) = {exp(2x) - 1}/{exp(2x) + 1}`: +`spv.Tanh` is modelled using the equality `tanh(x) = {exp(2x) - 1}/{exp(2x) + +1}`: ```mlir %two = llvm.mlir.constant(2.0: f32) : f32 @@ -778,20 +780,23 @@ to LLVM dialect. ### `spv.func` -This op declares or defines a SPIR-V function and it is converted to `llvm.func`. -This conversion handles signature conversion, and function control attributes -remapping to LLVM dialect function [`passthrough` attribute](Dialects/LLVM.md/#attribute-pass-through). -The following mapping is used to map [SPIR-V function control][SPIRVFunctionAttributes] to +This op declares or defines a SPIR-V function and it is converted to +`llvm.func`. 
This conversion handles signature conversion, and function control +attributes remapping to LLVM dialect function +[`passthrough` attribute](Dialects/LLVM.md/#attribute-pass-through). + +The following mapping is used to map +[SPIR-V function control][SPIRVFunctionAttributes] to [LLVM function attributes][LLVMFunctionAttributes]: -SPIR-V Function Control Attributes | LLVM Function Attributes -:-----------------------------------: | :-----------------------------------: -None | No function attributes passed -Inline | `alwaysinline` -DontInline | `noinline` -Pure | `readonly` -Const | `readnone` +SPIR-V Function Control Attributes | LLVM Function Attributes +:--------------------------------: | :---------------------------: +None | No function attributes passed +Inline | `alwaysinline` +DontInline | `noinline` +Pure | `readonly` +Const | `readnone` ### `spv.Return` and `spv.ReturnValue` @@ -816,10 +821,8 @@ SPIR-V to LLVM dialect conversion. Currently, only single-threaded kernel is supported. -To build the runner, add the following option to `cmake`: -```bash --DMLIR_ENABLE_SPIRV_CPU_RUNNER=1 -``` +To build the runner, add the following option to `cmake`: `bash +-DMLIR_ENABLE_SPIRV_CPU_RUNNER=1` ### Pipeline @@ -857,7 +860,7 @@ func @main() { // Fill the buffer with some data - %buffer = alloc : memref<8xi32> + %buffer = memref.alloc : memref<8xi32> %data = ... call fillBuffer(%buffer, %data) @@ -880,7 +883,7 @@ func @main() { // Fill the buffer with some data. - %buffer = alloc : memref<8xi32> + %buffer = memref.alloc : memref<8xi32> %data = ... call fillBuffer(%buffer, %data) diff --git a/mlir/docs/SymbolsAndSymbolTables.md b/mlir/docs/SymbolsAndSymbolTables.md --- a/mlir/docs/SymbolsAndSymbolTables.md +++ b/mlir/docs/SymbolsAndSymbolTables.md @@ -2,11 +2,11 @@ [TOC] -With [Regions](LangRef.md/#regions), the multi-level aspect of MLIR is structural -in the IR. A lot of infrastructure within the compiler is built around this -nesting structure; including the processing of operations within the -[pass manager](PassManagement.md/#pass-manager). One advantage of the MLIR design -is that it is able to process operations in parallel, utilizing multiple +With [Regions](LangRef.md/#regions), the multi-level aspect of MLIR is +structural in the IR. A lot of infrastructure within the compiler is built +around this nesting structure; including the processing of operations within the +[pass manager](PassManagement.md/#pass-manager). One advantage of the MLIR +design is that it is able to process operations in parallel, utilizing multiple threads. This is possible due to a property of the IR known as [`IsolatedFromAbove`](Traits.md/#isolatedfromabove). @@ -137,13 +137,13 @@ different trade offs depending on the situation. A function call may directly use a `SymbolRef` as the callee, whereas a reference to a global variable might use a materialization operation so that the variable can be used in other -operations like `std.addi`. -[`llvm.mlir.addressof`](Dialects/LLVM.md/#llvmmliraddressof-mlirllvmaddressofop) is one example of -such an operation. +operations like `arith.addi`. +[`llvm.mlir.addressof`](Dialects/LLVM.md/#llvmmliraddressof-mlirllvmaddressofop) +is one example of such an operation. See the `LangRef` definition of the -[`SymbolRefAttr`](Dialects/Builtin.md/#symbolrefattr) for more information -about the structure of this attribute. +[`SymbolRefAttr`](Dialects/Builtin.md/#symbolrefattr) for more information about +the structure of this attribute. 
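As a hedged illustration of the materialization pattern mentioned above (the global's name and type are invented for this sketch), a symbol-referencing operation produces an SSA value that ordinary operations can then consume:

```mlir
llvm.mlir.global internal @counter(0 : i32) : i32

llvm.func @addr_of_counter() -> !llvm.ptr<i32> {
  // The symbol reference is materialized into an SSA value (a pointer)
  // that other operations in the body can use.
  %ptr = llvm.mlir.addressof @counter : !llvm.ptr<i32>
  llvm.return %ptr : !llvm.ptr<i32>
}
```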
Operations that reference a `Symbol` and want to perform verification and general mutation of the symbol should implement the `SymbolUserOpInterface` to diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md --- a/mlir/docs/TargetLLVMIR.md +++ b/mlir/docs/TargetLLVMIR.md @@ -305,8 +305,8 @@ return %arg0, %arg1 : i32, i64 } func @bar() { - %0 = constant 42 : i32 - %1 = constant 17 : i64 + %0 = arith.constant 42 : i32 + %1 = arith.constant 17 : i64 %2:2 = call @foo(%0, %1) : (i32, i64) -> (i32, i64) "use_i32"(%2#0) : (i32) -> () "use_i64"(%2#1) : (i64) -> () @@ -768,7 +768,7 @@ An access to a memref with indices: ```mlir -%0 = load %m[%1,%2,%3,%4] : memref +%0 = memref.load %m[%1,%2,%3,%4] : memref ``` is transformed into the equivalent of the following code: @@ -779,27 +779,27 @@ // dynamic, extract the stride value from the descriptor. %stride1 = llvm.extractvalue[4, 0] : !llvm.struct<(ptr, ptr, i64, array<4xi64>, array<4xi64>)> -%addr1 = muli %stride1, %1 : i64 +%addr1 = arith.muli %stride1, %1 : i64 // When the stride or, in absence of explicit strides, the trailing sizes are // known statically, this value is used as a constant. The natural value of // strides is the product of all sizes following the current dimension. %stride2 = llvm.mlir.constant(32 : index) : i64 -%addr2 = muli %stride2, %2 : i64 -%addr3 = addi %addr1, %addr2 : i64 +%addr2 = arith.muli %stride2, %2 : i64 +%addr3 = arith.addi %addr1, %addr2 : i64 %stride3 = llvm.mlir.constant(8 : index) : i64 -%addr4 = muli %stride3, %3 : i64 -%addr5 = addi %addr3, %addr4 : i64 +%addr4 = arith.muli %stride3, %3 : i64 +%addr5 = arith.addi %addr3, %addr4 : i64 // Multiplication with the known unit stride can be omitted. -%addr6 = addi %addr5, %4 : i64 +%addr6 = arith.addi %addr5, %4 : i64 // If the linear offset is known to be zero, it can also be omitted. If it is // dynamic, it is extracted from the descriptor. %offset = llvm.extractvalue[2] : !llvm.struct<(ptr, ptr, i64, array<4xi64>, array<4xi64>)> -%addr7 = addi %addr6, %offset : i64 +%addr7 = arith.addi %addr6, %offset : i64 // All accesses are based on the aligned pointer. %aligned = llvm.extractvalue[1] : !llvm.struct<(ptr, ptr, i64, diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -56,13 +56,12 @@ `verifyTrait` hook out-of-line as a free function when possible to avoid instantiating the implementation for every concrete operation type. -Operation traits may also provide a `foldTrait` hook that is called when -folding the concrete operation. The trait folders will only be invoked if -the concrete operation fold is either not implemented, fails, or performs -an in-place fold. +Operation traits may also provide a `foldTrait` hook that is called when folding +the concrete operation. The trait folders will only be invoked if the concrete +operation fold is either not implemented, fails, or performs an in-place fold. -The following signature of fold will be called if it is implemented -and the op has a single result. +The following signature of fold will be called if it is implemented and the op +has a single result. 
```c++ template @@ -76,8 +75,8 @@ }; ``` -Otherwise, if the operation has a single result and the above signature is -not implemented, or the operation has multiple results, then the following signature +Otherwise, if the operation has a single result and the above signature is not +implemented, or the operation has multiple results, then the following signature will be used (if implemented): ```c++ @@ -200,9 +199,9 @@ such operations automatically become valid symbols for the polyhedral scope defined by that operation. As a result, such SSA values could be used as the operands or index operands of various affine dialect operations like affine.for, -affine.load, and affine.store. The polyhedral scope defined by an operation -with this trait includes all operations in its region excluding operations that -are nested inside of other operations that themselves have this trait. +affine.load, and affine.store. The polyhedral scope defined by an operation with +this trait includes all operations in its region excluding operations that are +nested inside of other operations that themselves have this trait. ### AutomaticAllocationScope @@ -211,7 +210,8 @@ This trait is carried by region holding operations that define a new scope for automatic allocation. Such allocations are automatically freed when control is transferred back from the regions of such operations. As an example, allocations -performed by [`memref.alloca`](Dialects/MemRef.md/#memrefalloca-mlirmemrefallocaop) are +performed by +[`memref.alloca`](Dialects/MemRef.md/#memrefalloca-mlirmemrefallocaop) are automatically freed when control leaves the region of its closest surrounding op that has the trait AutomaticAllocationScope. @@ -241,7 +241,7 @@ ### ElementwiseMappable -* `OpTrait::ElementwiseMappable` -- `ElementwiseMappable` +* `OpTrait::ElementwiseMappable` -- `ElementwiseMappable` This trait tags scalar ops that also can be applied to vectors/tensors, with their semantics on vectors/tensors being elementwise application. This trait @@ -300,7 +300,7 @@ `IsolatedFromAbove`: ```mlir -%result = constant 10 : i32 +%result = arith.constant 10 : i32 foo.region_op { foo.yield %result : i32 } @@ -311,14 +311,13 @@ ### MemRefsNormalizable -* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable` +* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable` -This trait is used to flag operations that consume or produce -values of `MemRef` type where those references can be 'normalized'. -In cases where an associated `MemRef` has a -non-identity memory-layout specification, such normalizable operations can be -modified so that the `MemRef` has an identity layout specification. -This can be implemented by associating the operation with its own +This trait is used to flag operations that consume or produce values of `MemRef` +type where those references can be 'normalized'. In cases where an associated +`MemRef` has a non-identity memory-layout specification, such normalizable +operations can be modified so that the `MemRef` has an identity layout +specification. This can be implemented by associating the operation with its own index expression that can express the equivalent of the memory-layout specification of the MemRef type. See [the -normalize-memrefs pass]. 
(https://mlir.llvm.org/docs/Passes/#-normalize-memrefs-normalize-memrefs) diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md --- a/mlir/docs/Tutorials/Toy/Ch-5.md +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -15,20 +15,20 @@ `Affine` for the computation heavy part of Toy, and in the [next chapter](Ch-6.md) directly target the `LLVM IR` dialect for lowering `print`. As part of this lowering, we will be lowering from the -[TensorType](../../Dialects/Builtin.md/#rankedtensortype) that `Toy` -operates on to the [MemRefType](../../Dialects/Builtin.md/#memreftype) that is -indexed via an affine loop-nest. Tensors represent an abstract value-typed -sequence of data, meaning that they don't live in any memory. MemRefs, on the -other hand, represent lower level buffer access, as they are concrete -references to a region of memory. +[TensorType](../../Dialects/Builtin.md/#rankedtensortype) that `Toy` operates on +to the [MemRefType](../../Dialects/Builtin.md/#memreftype) that is indexed via +an affine loop-nest. Tensors represent an abstract value-typed sequence of data, +meaning that they don't live in any memory. MemRefs, on the other hand, +represent lower level buffer access, as they are concrete references to a region +of memory. # Dialect Conversions MLIR has many different dialects, so it is important to have a unified framework -for [converting](../../../getting_started/Glossary.md/#conversion) between them. This is where the -`DialectConversion` framework comes into play. This framework allows for -transforming a set of *illegal* operations to a set of *legal* ones. To use this -framework, we need to provide two things (and an optional third): +for [converting](../../../getting_started/Glossary.md/#conversion) between them. +This is where the `DialectConversion` framework comes into play. This framework +allows for transforming a set of *illegal* operations to a set of *legal* ones. +To use this framework, we need to provide two things (and an optional third): * A [Conversion Target](../../DialectConversion.md/#conversion-target) @@ -40,8 +40,8 @@ * A set of [Rewrite Patterns](../../DialectConversion.md/#rewrite-pattern-specification) - - This is the set of [patterns](../QuickstartRewrites.md) used to - convert *illegal* operations into a set of zero or more *legal* ones. + - This is the set of [patterns](../QuickstartRewrites.md) used to convert + *illegal* operations into a set of zero or more *legal* ones. * Optionally, a [Type Converter](../../DialectConversion.md/#type-conversion). @@ -63,9 +63,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want @@ -77,11 +77,10 @@ } ``` -Above, we first set the toy dialect to illegal, and then the print operation -as legal. We could have done this the other way around. -Individual operations always take precedence over the (more generic) dialect -definitions, so the order doesn't matter. See `ConversionTarget::getOpInfo` -for the details. +Above, we first set the toy dialect to illegal, and then the print operation as +legal. We could have done this the other way around. 
Individual operations +always take precedence over the (more generic) dialect definitions, so the order +doesn't matter. See `ConversionTarget::getOpInfo` for the details. ## Conversion Patterns @@ -97,9 +96,9 @@ remapped/replaced. This is used when dealing with type conversions, as the pattern will want to operate on values of the new type but match against the old. For our lowering, this invariant will be useful as it translates from the -[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently -being operated on to the [MemRefType](../../Dialects/Builtin.md/#memreftype). -Let's look at a snippet of lowering the `toy.transpose` operation: +[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being +operated on to the [MemRefType](../../Dialects/Builtin.md/#memreftype). Let's +look at a snippet of lowering the `toy.transpose` operation: ```c++ /// Lower the `toy.transpose` operation to an affine loop nest. @@ -185,29 +184,29 @@ * Generate `load` operations from the buffer - One option is to generate `load` operations from the buffer type to materialize - an instance of the value type. This allows for the definition of the `toy.print` - operation to remain unchanged. The downside to this approach is that the - optimizations on the `affine` dialect are limited, because the `load` will - actually involve a full copy that is only visible *after* our optimizations have - been performed. + One option is to generate `load` operations from the buffer type to + materialize an instance of the value type. This allows for the definition of + the `toy.print` operation to remain unchanged. The downside to this approach + is that the optimizations on the `affine` dialect are limited, because the + `load` will actually involve a full copy that is only visible *after* our + optimizations have been performed. * Generate a new version of `toy.print` that operates on the lowered type - Another option would be to have another, lowered, variant of `toy.print` that - operates on the lowered type. The benefit of this option is that there is no - hidden, unnecessary copy to the optimizer. The downside is that another - operation definition is needed that may duplicate many aspects of the first. - Defining a base class in [ODS](../../OpDefinitions.md) may simplify this, but - you still need to treat these operations separately. + Another option would be to have another, lowered, variant of `toy.print` + that operates on the lowered type. The benefit of this option is that there + is no hidden, unnecessary copy to the optimizer. The downside is that + another operation definition is needed that may duplicate many aspects of + the first. Defining a base class in [ODS](../../OpDefinitions.md) may + simplify this, but you still need to treat these operations separately. * Update `toy.print` to allow for operating on the lowered type - A third option is to update the current definition of `toy.print` to allow for - operating the on the lowered type. The benefit of this approach is that it is - simple, does not introduce an additional hidden copy, and does not require - another operation definition. The downside to this option is that it requires - mixing abstraction levels in the `Toy` dialect. + A third option is to update the current definition of `toy.print` to allow + for operating on the lowered type. The benefit of this approach is that it + is simple, does not introduce an additional hidden copy, and does not + require another operation definition.
The downside to this option is that it + requires mixing abstraction levels in the `Toy` dialect. For the sake of simplicity, we will use the third option for this lowering. This involves updating the type constraints on the PrintOp in the operation @@ -241,17 +240,17 @@ ```mlir func @main() { - %cst = constant 1.000000e+00 : f64 - %cst_0 = constant 2.000000e+00 : f64 - %cst_1 = constant 3.000000e+00 : f64 - %cst_2 = constant 4.000000e+00 : f64 - %cst_3 = constant 5.000000e+00 : f64 - %cst_4 = constant 6.000000e+00 : f64 + %cst = arith.constant 1.000000e+00 : f64 + %cst_0 = arith.constant 2.000000e+00 : f64 + %cst_1 = arith.constant 3.000000e+00 : f64 + %cst_2 = arith.constant 4.000000e+00 : f64 + %cst_3 = arith.constant 5.000000e+00 : f64 + %cst_4 = arith.constant 6.000000e+00 : f64 // Allocating buffers for the inputs and outputs. - %0 = alloc() : memref<3x2xf64> - %1 = alloc() : memref<3x2xf64> - %2 = alloc() : memref<2x3xf64> + %0 = memref.alloc() : memref<3x2xf64> + %1 = memref.alloc() : memref<3x2xf64> + %2 = memref.alloc() : memref<2x3xf64> // Initialize the input buffer with the constant values. affine.store %cst, %2[0, 0] : memref<2x3xf64> @@ -275,16 +274,16 @@ affine.for %arg1 = 0 to 2 { %3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64> %4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64> - %5 = mulf %3, %4 : f64 + %5 = arith.mulf %3, %4 : f64 affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64> } } // Print the value held by the buffer. toy.print %0 : memref<3x2xf64> - dealloc %2 : memref<2x3xf64> - dealloc %1 : memref<3x2xf64> - dealloc %0 : memref<3x2xf64> + memref.dealloc %2 : memref<2x3xf64> + memref.dealloc %1 : memref<3x2xf64> + memref.dealloc %0 : memref<3x2xf64> return } ``` @@ -299,16 +298,16 @@ ```mlir func @main() { - %cst = constant 1.000000e+00 : f64 - %cst_0 = constant 2.000000e+00 : f64 - %cst_1 = constant 3.000000e+00 : f64 - %cst_2 = constant 4.000000e+00 : f64 - %cst_3 = constant 5.000000e+00 : f64 - %cst_4 = constant 6.000000e+00 : f64 + %cst = arith.constant 1.000000e+00 : f64 + %cst_0 = arith.constant 2.000000e+00 : f64 + %cst_1 = arith.constant 3.000000e+00 : f64 + %cst_2 = arith.constant 4.000000e+00 : f64 + %cst_3 = arith.constant 5.000000e+00 : f64 + %cst_4 = arith.constant 6.000000e+00 : f64 // Allocating buffers for the inputs and outputs. - %0 = alloc() : memref<3x2xf64> - %1 = alloc() : memref<2x3xf64> + %0 = memref.alloc() : memref<3x2xf64> + %1 = memref.alloc() : memref<2x3xf64> // Initialize the input buffer with the constant values. affine.store %cst, %1[0, 0] : memref<2x3xf64> @@ -324,15 +323,15 @@ %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64> // Multiply and store into the output buffer. - %3 = mulf %2, %2 : f64 + %3 = arith.mulf %2, %2 : f64 affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64> } } // Print the value held by the buffer. toy.print %0 : memref<3x2xf64> - dealloc %1 : memref<2x3xf64> - dealloc %0 : memref<3x2xf64> + memref.dealloc %1 : memref<2x3xf64> + memref.dealloc %0 : memref<3x2xf64> return } ``` diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -16,12 +16,13 @@ Before going over the conversion to LLVM, let's lower the `toy.print` operation. We will lower this operation to a non-affine loop nest that invokes `printf` for each element. 
Note that, because the dialect conversion framework supports -[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering), we don't need to -directly emit operations in the LLVM dialect. By transitive lowering, we mean -that the conversion framework may apply multiple patterns to fully legalize an -operation. In this example, we are generating a structured loop nest instead of -the branch-form in the LLVM dialect. As long as we then have a lowering from the -loop operations to LLVM, the lowering will still succeed. +[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering), +we don't need to directly emit operations in the LLVM dialect. By transitive +lowering, we mean that the conversion framework may apply multiple patterns to +fully legalize an operation. In this example, we are generating a structured +loop nest instead of the branch-form in the LLVM dialect. As long as we then +have a lowering from the loop operations to LLVM, the lowering will still +succeed. During lowering we can get, or build, the declaration for printf as so: @@ -84,15 +85,17 @@ Now that the conversion target has been defined, we need to provide the patterns used for lowering. At this point in the compilation process, we have a -combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and -`affine` dialects already provide the set of patterns needed to transform them -into LLVM dialect. These patterns allow for lowering the IR in multiple stages -by relying on [transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering). +combination of `toy`, `affine`, `arith`, and `std` operations. Luckily, the +`affine`, `arith`, and `std` dialects already provide the set of patterns needed +to transform them into LLVM dialect. These patterns allow for lowering the IR in +multiple stages by relying on +[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering). ```c++ mlir::RewritePatternSet patterns(&getContext()); mlir::populateAffineToStdConversionPatterns(patterns, &getContext()); mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); // The only remaining operation, to lower from the `toy` dialect, is the @@ -322,7 +325,7 @@ [`--print-ir-after-all`](../../PassManagement.md/#ir-printing) to track the evolution of the IR throughout the pipeline. -The example code used throughout this section can be found in +The example code used throughout this section can be found in test/Examples/Toy/Ch6/llvm-lowering.mlir. So far, we have worked with primitive data types.
In the diff --git a/mlir/docs/includes/img/branch_example_post_move.svg b/mlir/docs/includes/img/branch_example_post_move.svg --- a/mlir/docs/includes/img/branch_example_post_move.svg +++ b/mlir/docs/includes/img/branch_example_post_move.svg @@ -414,6 +414,6 @@ id="tspan3407" x="21.911886" y="15.884925" - style="font-size:5.64444px;fill:#008000;stroke-width:0.264583">%0 = alloc() + style="font-size:5.64444px;fill:#008000;stroke-width:0.264583">%0 = memref.alloc() diff --git a/mlir/docs/includes/img/branch_example_pre_move.svg b/mlir/docs/includes/img/branch_example_pre_move.svg --- a/mlir/docs/includes/img/branch_example_pre_move.svg +++ b/mlir/docs/includes/img/branch_example_pre_move.svg @@ -353,7 +353,7 @@ transform="translate(8.4353227,-0.28369449)">%0 = alloc()%0 = memref.alloc() %1 = alloc(%0)%1 = memref.alloc(%0)%5 = alloc(%d0)%5 = memref.alloc(%d0)%6 = alloc(%d1)%6 = memref.alloc(%d1)%1 = alloc(%0)%1 = memref.alloc(%0)(); - registry.insert(); + registry.insert(); // Add the following to include *all* MLIR Core dialects, or selectively // include what you need like above. You only need to register dialects that // will be *parsed* by the tool, not the one generated diff --git a/mlir/examples/standalone/test/Standalone/dummy.mlir b/mlir/examples/standalone/test/Standalone/dummy.mlir --- a/mlir/examples/standalone/test/Standalone/dummy.mlir +++ b/mlir/examples/standalone/test/Standalone/dummy.mlir @@ -3,7 +3,7 @@ module { // CHECK-LABEL: func @bar() func @bar() { - %0 = constant 1 : i32 + %0 = arith.constant 1 : i32 // CHECK: %{{.*}} = standalone.foo %{{.*}} : i32 %res = standalone.foo %0 : i32 return diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + constantIndices.push_back( + rewriter.create(loc, i)); } else { // This is the case of a tensor of rank 0. - constantIndices.push_back(rewriter.create(loc, 0)); + constantIndices.push_back( + rewriter.create(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -171,7 +174,7 @@ // we store the element at the given index. if (dimension == valueShape.size()) { rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, + loc, rewriter.create(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -284,9 +287,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. 
+ target.addLegalDialect<AffineDialect, arith::ArithmeticDialect, memref::MemRefDialect, StandardOpsDialect>(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>; -using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>; +using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>; +using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq<int64_t>( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i)); + constantIndices.push_back( + rewriter.create<arith::ConstantIndexOp>(loc, i)); } else { // This is the case of a tensor of rank 0. - constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0)); + constantIndices.push_back( + rewriter.create<arith::ConstantIndexOp>(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we // will need to generate a store for each of the elements. The following @@ -170,7 +173,7 @@ // we store the element at the given index. if (dimension == valueShape.size()) { rewriter.create<AffineStoreOp>( - loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc, + loc, rewriter.create<arith::ConstantOp>(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -283,9 +286,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect<AffineDialect, memref::MemRefDialect, StandardOpsDialect>(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect<AffineDialect, arith::ArithmeticDialect, memref::MemRefDialect, StandardOpsDialect>(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -25,6 +25,7 @@ #include "toy/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -32,6 +33,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -73,9 +75,10 @@ // Create a loop for each of the dimensions within the shape.
SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = + rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); auto loop = rewriter.create(loc, lowerBound, upperBound, step); for (Operation &nested : *loop.getBody()) @@ -198,6 +201,8 @@ RewritePatternSet patterns(&getContext()); populateAffineToStdConversionPatterns(patterns); populateLoopToStdConversionPatterns(patterns); + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + constantIndices.push_back( + rewriter.create(loc, i)); } else { // This is the case of a tensor of rank 0. - constantIndices.push_back(rewriter.create(loc, 0)); + constantIndices.push_back( + rewriter.create(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -171,7 +174,7 @@ // we store the element at the given index. if (dimension == valueShape.size()) { rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, + loc, rewriter.create(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -284,9 +287,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. 
Given that we actually want diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -25,6 +25,7 @@ #include "toy/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -32,6 +33,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -73,9 +75,10 @@ // Create a loop for each of the dimensions within the shape. SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = + rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); auto loop = rewriter.create(loc, lowerBound, upperBound, step); for (Operation &nested : *loop.getBody()) @@ -198,6 +201,8 @@ RewritePatternSet patterns(&getContext()); populateAffineToStdConversionPatterns(patterns); populateLoopToStdConversionPatterns(patterns); + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); diff --git a/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h b/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h @@ -0,0 +1,28 @@ +//===- ArithmeticToLLVM.h - Arith to LLVM dialect conversion ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H +#define MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H + +#include + +namespace mlir { + +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; + +namespace arith { +void populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +std::unique_ptr createConvertArithmeticToLLVMPass(); +} // end namespace arith +} // end namespace mlir + +#endif // MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H diff --git a/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h @@ -0,0 +1,28 @@ +//===- ArithmeticToSPIRV.h - Convert Arith to SPIRV dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H +#define MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H + +#include + +namespace mlir { + +class SPIRVTypeConverter; +class RewritePatternSet; +class Pass; + +namespace arith { +void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); + +std::unique_ptr createConvertArithmeticToSPIRVPass(); +} // end namespace arith +} // end namespace mlir + +#endif // MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -10,6 +10,8 @@ #define MLIR_CONVERSION_PASSES_H #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -39,10 +39,10 @@ %d0 = <...> %d1 = <...> %s0 = <...> - %0 = constant 2 : index - %1 = muli %0, %d1 - %2 = addi %d0, %1 - %r = addi %2, %s0 + %0 = arith.constant 2 : index + %1 = arith.muli %0, %d1 + %2 = arith.addi %d0, %1 + %r = arith.addi %2, %s0 ``` #### Input invariant @@ -74,6 +74,40 @@ ]; } +//===----------------------------------------------------------------------===// +// ArithmeticToLLVM +//===----------------------------------------------------------------------===// + +def ConvertArithmeticToLLVM : FunctionPass<"convert-arith-to-llvm"> { + let summary = "Convert Arithmetic dialect to LLVM dialect"; + let description = [{ + This pass converts supported Arithmetic ops to LLVM dialect instructions. + }]; + let constructor = "mlir::arith::createConvertArithmeticToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; + let options = [ + Option<"indexBitwidth", "index-bitwidth", "unsigned", + /*default=kDeriveIndexBitwidthFromDataLayout*/"0", + "Bitwidth of the index type, 0 to use size of machine word">, + ]; +} + +//===----------------------------------------------------------------------===// +// ArithmeticToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertArithmeticToSPIRV : FunctionPass<"convert-arith-to-spirv"> { + let summary = "Convert Arithmetic dialect to SPIR-V dialect"; + let constructor = "mlir::arith::createConvertArithmeticToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; + let options = [ + Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types", + "bool", /*default=*/"true", + "Emulate non-32-bit scalar types with 32-bit ones if " + "missing native support"> + ]; +} + //===----------------------------------------------------------------------===// // AsyncToLLVM //===----------------------------------------------------------------------===// @@ -86,7 +120,10 @@ API to execute them. 
}]; let constructor = "mlir::createConvertAsyncToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "LLVM::LLVMDialect", + ]; } //===----------------------------------------------------------------------===// @@ -106,11 +143,7 @@ def ConvertComplexToStandard : FunctionPass<"convert-complex-to-standard"> { let summary = "Convert Complex dialect to standard dialect"; let constructor = "mlir::createConvertComplexToStandardPass()"; - let dependentDialects = [ - "complex::ComplexDialect", - "math::MathDialect", - "StandardOpsDialect" - ]; + let dependentDialects = ["math::MathDialect"]; } //===----------------------------------------------------------------------===// @@ -136,7 +169,11 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { let summary = "Generate NVVM operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()"; - let dependentDialects = ["NVVM::NVVMDialect", "memref::MemRefDialect"]; + let dependentDialects = [ + "memref::MemRefDialect", + "NVVM::NVVMDialect", + "StandardOpsDialect", + ]; let options = [ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", @@ -252,7 +289,11 @@ This pass converts supported Math ops to libm calls. }]; let constructor = "mlir::createConvertMathToLibmPass()"; - let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "StandardOpsDialect", + "vector::VectorDialect", + ]; } //===----------------------------------------------------------------------===// @@ -448,7 +489,6 @@ let dependentDialects = [ "StandardOpsDialect", "scf::SCFDialect", - "tensor::TensorDialect" ]; } @@ -583,7 +623,11 @@ def TosaToStandard : Pass<"tosa-to-standard"> { let summary = "Lower TOSA to the Standard dialect"; - let dependentDialects = ["StandardOpsDialect", "tensor::TensorDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "StandardOpsDialect", + "tensor::TensorDialect", + ]; let description = [{ Pass that converts TOSA operations to the equivalent operations using the operations in the Standard dialect. diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h --- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h +++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h @@ -37,7 +37,7 @@ /// affine.for %I = 0 to 9 { /// %dim = dim %A, 0 : memref /// %add = affine.apply %I + %a -/// %cmp = cmpi "slt", %add, %dim : index +/// %cmp = arith.cmpi "slt", %add, %dim : index /// scf.if %cmp { /// %vec_2d = load %1[%I] : memref<9xvector<17x15xf32>> /// vector.transfer_write %vec_2d, %A[%add, %b, %c] : diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -23,6 +23,7 @@ let name = "affine"; let cppNamespace = "mlir"; let hasConstantMaterializer = 1; + let dependentDialects = ["arith::ArithmeticDialect"]; } // Base class for Affine dialect ops. @@ -201,7 +202,7 @@ %sum = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32) { %t = affine.load %buffer[%i] : memref<1024xf32> - %sum_next = addf %sum_iter, %t : f32 + %sum_next = arith.addf %sum_iter, %t : f32 // Yield current iteration sum to next iteration %sum_iter or to %sum // if final iteration. 
affine.yield %sum_next : f32 @@ -213,8 +214,8 @@ ```mlir %res:2 = affine.for %i = 0 to 128 iter_args(%arg0 = %init0, %arg1 = %init1) -> (index, index) { - %y0 = addi %arg0, %c1 : index - %y1 = addi %arg1, %c2 : index + %y0 = arith.addi %arg0, %c1 : index + %y1 = arith.addi %arg1, %c2 : index affine.yield %y0, %y1 : index, index } ``` @@ -656,7 +657,7 @@ %0 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addf") { %1 = affine.load %D[%x + %kx, %y + %ky] : memref<100x100xf32> %2 = affine.load %K[%kx, %ky] : memref<3x3xf32> - %3 = mulf %1, %2 : f32 + %3 = arith.mulf %1, %2 : f32 affine.yield %3 : f32 } affine.store %0, O[%x, %y] : memref<98x98xf32> diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -112,7 +112,7 @@ affine.for %i1 = 0 to 10 { affine.store %cf7, %m[%i0, %i1] : memref<10x10xf32> %v0 = affine.load %m[%i0, %i1] : memref<10x10xf32> - %v1 = addf %v0, %v0 : f32 + %v1 = arith.addf %v0, %v0 : f32 } } return %m : memref<10x10xf32> @@ -129,7 +129,7 @@ affine.for %arg0 = 0 to 10 { affine.for %arg1 = 0 to 10 { affine.store %cst, %0[%arg0, %arg1] : memref<10x10xf32> - %1 = addf %cst, %cst : f32 + %1 = arith.addf %cst, %cst : f32 } } return %0 : memref<10x10xf32> diff --git a/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -10,6 +10,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" @@ -33,6 +34,64 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.h.inc" +namespace mlir { +namespace arith { + +/// Specialization of `arith.constant` op that returns an integer value. +class ConstantIntOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant int op that produces an integer of the specified width. + static void build(OpBuilder &builder, OperationState &result, int64_t value, + unsigned width); + + /// Build a constant int op that produces an integer of the specified type, + /// which must be an integer type. + static void build(OpBuilder &builder, OperationState &result, int64_t value, + Type type); + + inline int64_t value() { + return arith::ConstantOp::value().cast().getInt(); + } + + static bool classof(Operation *op); +}; + +/// Specialization of `arith.constant` op that returns a floating point value. +class ConstantFloatOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant float op that produces a float of the specified type. + static void build(OpBuilder &builder, OperationState &result, + const APFloat &value, FloatType type); + + inline APFloat value() { + return arith::ConstantOp::value().cast().getValue(); + } + + static bool classof(Operation *op); +}; + +/// Specialization of `arith.constant` op that returns an integer of index type. 
+class ConstantIndexOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant int op that produces an index. + static void build(OpBuilder &builder, OperationState &result, int64_t value); + + inline int64_t value() { + return arith::ConstantOp::value().cast().getInt(); + } + + static bool classof(Operation *op); +}; + +} // end namespace arith +} // end namespace mlir + //===----------------------------------------------------------------------===// // Utility Functions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td @@ -20,6 +20,8 @@ ops, bitwise and shift ops, cast ops, and compare ops. Operations in this dialect also accept vectors and tensors of integers or floats. }]; + + let hasConstantMaterializer = 1; } // The predicate indicates the type of the comparison to perform: diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -13,6 +13,7 @@ include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" +include "mlir/IR/OpAsmInterface.td" // Base class for Arithmetic dialect ops. Ops in this dialect have no side // effects and can be applied element-wise to vectors and tensors. @@ -119,12 +120,14 @@ //===----------------------------------------------------------------------===// def Arith_ConstantOp : Op, + TypesMatchWith< + "result and attribute have the same type", "value", "result", "$_self">]> { let summary = "integer or floating point constant"; let description = [{ - The `const` operation produces an SSA value equal to some integer or + The `constant` operation produces an SSA value equal to some integer or floating-point constant specified by an attribute. This is the way MLIR forms simple integer and floating point constants. @@ -140,7 +143,14 @@ }]; let arguments = (ins AnyAttr:$value); - let results = (outs SignlessIntegerOrFloatLike:$result); + // TODO: Disallow arith.constant to return anything other than a signless + // integer or float like. Downstream users of Arithmetic should only be + // working with signless integers, floats, or vectors/tensors thereof. + // However, it is necessary to allow arith.constant to return vectors/tensors + // of strings and signed/unsigned integers (for now) as an artefact of + // splitting the Standard dialect. + let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result); + let verifier = [{ return ::verify(*this); }]; let builders = [ OpBuilder<(ins "Attribute":$value), @@ -149,6 +159,12 @@ [{ build($_builder, $_state, type, value); }]>, ]; + let extraClassDeclaration = [{ + /// Whether the constant op can be constructed with a particular value and + /// type. + static bool isBuildableWith(Attribute value, Type type); + }]; + let hasFolder = 1; let assemblyFormat = "attr-dict $value"; } @@ -351,13 +367,13 @@ ```mlir // Scalar signed integer division remainder. - %a = remsi %b, %c : i64 + %a = arith.remsi %b, %c : i64 // SIMD vector element-wise division remainder. 
- %f = remsi %g, %h : vector<4xi32> + %f = arith.remsi %g, %h : vector<4xi32> // Tensor element-wise integer division remainder. - %x = remsi %y, %z : tensor<4x?xi8> + %x = arith.remsi %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; @@ -717,10 +733,10 @@ ```mlir %1 = arith.constant 21 : i5 // %1 is 0b10101 - %2 = trunci %1 : i5 to i4 // %2 is 0b0101 - %3 = trunci %1 : i5 to i3 // %3 is 0b101 + %2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101 + %3 = arith.trunci %1 : i5 to i3 // %3 is 0b101 - %5 = trunci %0 : vector<2 x i32> to vector<2 x i16> + %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16> ``` }]; @@ -803,7 +819,14 @@ // IndexCastOp //===----------------------------------------------------------------------===// -def Arith_IndexCastOp : Arith_IToICastOp<"index_cast"> { +// Index cast can convert between memrefs of signless integers and indices too. +def IndexCastTypeConstraint : TypeConstraint.predicate]>, + "signless-integer-like or memref of signless-integer">; + +def Arith_IndexCastOp : Arith_CastOp<"index_cast", IndexCastTypeConstraint, + IndexCastTypeConstraint> { let summary = "cast between index and integer types"; let description = [{ Casts between scalar or vector integers and corresponding 'index' scalar or @@ -820,8 +843,15 @@ // BitcastOp //===----------------------------------------------------------------------===// -def Arith_BitcastOp : Arith_CastOp<"bitcast", SignlessIntegerOrFloatLike, - SignlessIntegerOrFloatLike> { +// Bitcast can convert between memrefs of signless integers, indices, and +// floats too. +def BitcastTypeConstraint : TypeConstraint.predicate]>, + "signless-integer-or-float-like or memref of signless-integer or float">; + +def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint, + BitcastTypeConstraint> { let summary = "bitcast between values of equal bit width"; let description = [{ Bitcast an integer or floating point value to an integer or floating point @@ -927,10 +957,10 @@ let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); + static arith::CmpIPredicate getPredicateByName(StringRef name); - CmpIPredicate getPredicate() { - return (CmpIPredicate) (*this)->getAttrOfType( + arith::CmpIPredicate getPredicate() { + return (arith::CmpIPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -983,10 +1013,10 @@ let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); + static arith::CmpFPredicate getPredicateByName(StringRef name); - CmpFPredicate getPredicate() { - return (CmpFPredicate) (*this)->getAttrOfType( + arith::CmpFPredicate getPredicate() { + return (arith::CmpFPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Arithmetic) +add_public_tablegen_target(MLIRArithmeticTransformsIncGen) + +add_mlir_doc(Passes ArithmeticPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h new file mode 100644 --- /dev/null +++ 
b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h @@ -0,0 +1,42 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Bufferize.h" + +namespace mlir { +namespace arith { + +/// Add patterns to bufferize Arithmetic ops. +void populateArithmeticBufferizePatterns(BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns); + +/// Create a pass to bufferize Arithmetic ops. +std::unique_ptr createArithmeticBufferizePass(); + +/// Add patterns to expand Arithmetic ops for LLVM lowering. +void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns); + +/// Create a pass to legalize Arithmetic ops for LLVM lowering. +std::unique_ptr createArithmeticExpandOpsPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc" + +} // end namespace arith +} // end namespace mlir + +#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td @@ -0,0 +1,26 @@ +//===-- Passes.td - Arithmetic pass definition file --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES +#define MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def ArithmeticBufferize : FunctionPass<"arith-bufferize"> { + let summary = "Bufferize Arithmetic dialect ops."; + let constructor = "mlir::arith::createArithmeticBufferizePass()"; + let dependentDialects = ["memref::MemRefDialect"]; +} + +def ArithmeticExpandOps : FunctionPass<"arith-expand"> { + let summary = "Legalize Arithmetic ops to be convertible to LLVM."; + let constructor = "mlir::arith::createArithmeticExpandOpsPass()"; + let dependentDialects = ["StandardOpsDialect"]; +} + +#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -15,7 +15,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" -include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td" //===----------------------------------------------------------------------===// @@ -460,24 +460,24 @@ ``` }]; let arguments = (ins - CmpFPredicateAttr:$predicate, + Arith_CmpFPredicateAttr:$predicate, ScalableVectorOf<[AnyFloat]>:$lhs, ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar ); let results = (outs ScalableVectorOf<[I1]>:$result); let builders = [ - OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, + OpBuilder<(ins "arith::CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs); }]>]; let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); + static arith::CmpFPredicate getPredicateByName(StringRef name); - CmpFPredicate getPredicate() { - return (CmpFPredicate)(*this)->getAttrOfType( + arith::CmpFPredicate getPredicate() { + return (arith::CmpFPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -520,24 +520,24 @@ }]; let arguments = (ins - CmpIPredicateAttr:$predicate, + Arith_CmpIPredicateAttr:$predicate, ScalableVectorOf<[I8, I16, I32, I64]>:$lhs, ScalableVectorOf<[I8, I16, I32, I64]>:$rhs ); let results = (outs ScalableVectorOf<[I1]>:$result); let builders = [ - OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, + OpBuilder<(ins "arith::CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs); }]>]; let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); + static arith::CmpIPredicate getPredicateByName(StringRef name); - CmpIPredicate getPredicate() { - return (CmpIPredicate)(*this)->getAttrOfType( + arith::CmpIPredicate getPredicate() { + return (arith::CmpIPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -32,7 +32,11 @@ "The minimum task size for sharding parallel operation."> ]; - let 
dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "async::AsyncDialect", + "scf::SCFDialect" + ]; } def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> { diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h --- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h +++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h @@ -9,6 +9,8 @@ #ifndef MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td @@ -18,6 +18,9 @@ The complex dialect is intended to hold complex numbers creation and arithmetic ops. }]; + + let dependentDialects = ["arith::ArithmeticDialect", "StandardOpsDialect"]; + let hasConstantMaterializer = 1; } #endif // COMPLEX_BASE diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td --- a/mlir/include/mlir/Dialect/GPU/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td @@ -51,6 +51,8 @@ /// space. static unsigned getPrivateAddressSpace() { return 5; } }]; + + let dependentDialects = ["arith::ArithmeticDialect"]; } def GPU_AsyncToken : DialectType< diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_GPU_GPUDIALECT_H #define MLIR_DIALECT_GPU_GPUDIALECT_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -627,7 +627,7 @@ %1 = "gpu.all_reduce"(%0) ({}) { op = "add" } : (f32) -> (f32) %2 = "gpu.all_reduce"(%0) ({ ^bb(%lhs : f32, %rhs : f32): - %sum = addf %lhs, %rhs : f32 + %sum = arith.addf %lhs, %rhs : f32 "gpu.yield"(%sum) : (f32) -> () }) : (f32) -> (f32) ``` diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -33,11 +33,16 @@ }]; let cppNamespace = "::mlir::linalg"; let dependentDialects = [ - "AffineDialect", "math::MathDialect", "memref::MemRefDialect", - "StandardOpsDialect", "tensor::TensorDialect" + "arith::ArithmeticDialect", + "AffineDialect", + "math::MathDialect", + "memref::MemRefDialect", + "StandardOpsDialect", + "tensor::TensorDialect", ]; let hasCanonicalizer = 1; let hasOperationAttrVerify = 1; + let hasConstantMaterializer = 1; let extraClassDeclaration = [{ /// Attribute name used to to memoize indexing maps for named ops. 
constexpr const static ::llvm::StringLiteral diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -283,8 +283,8 @@ outs(%C : memref) {other-optional-attributes} { ^bb0(%a: f32, %b: f32, %c: f32) : - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 linalg.yield %e : f32 } ``` @@ -306,8 +306,8 @@ %a = load %A[%m, %k] : memref %b = load %B[%k, %n] : memref %c = load %C[%m, %n] : memref - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 store %e, %C[%m, %n] : memref } } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_LINALG_LINALGTYPES_H_ #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -143,7 +143,7 @@ let dependentDialects = [ "linalg::LinalgDialect", "AffineDialect", - "memref::MemRefDialect" + "memref::MemRefDialect", ]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -271,7 +271,7 @@ /// to /// /// %iv = %lb + %procId * %step - /// %cond = cmpi "slt", %iv, %ub + /// %cond = arith.cmpi "slt", %iv, %ub /// scf.if %cond { /// ... /// } diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_ #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td @@ -19,7 +19,7 @@ manipulation ops, which are not strongly associated with any particular other dialect or domain abstraction. 
}]; - let dependentDialects = ["tensor::TensorDialect"]; + let dependentDialects = ["arith::ArithmeticDialect", "tensor::TensorDialect"]; let hasConstantMaterializer = 1; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -158,7 +158,7 @@ omp.wsloop (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) { %a = load %arrA[%i1, %i2] : memref %b = load %arrB[%i1, %i2] : memref - %sum = addf %a, %b : f32 + %sum = arith.addf %a, %b : f32 store %sum, %arrC[%i1, %i2] : memref omp.yield } diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -94,18 +94,18 @@ ```mlir # Before: scf.for %i = %c0 to %arg1 step %c1 { - %0 = addi %arg2, %arg2 : i32 + %0 = arith.addi %arg2, %arg2 : i32 memref.store %0, %arg0[%i] : memref } # After: %0 = scf.while (%i = %c0) : (index) -> index { - %1 = cmpi slt, %i, %arg1 : index + %1 = arith.cmpi slt, %i, %arg1 : index scf.condition(%1) %i : index } do { ^bb0(%i: index): // no predecessors - %1 = addi %i, %c1 : index - %2 = addi %arg2, %arg2 : i32 + %1 = arith.addi %i, %c1 : index + %2 = arith.addi %arg2, %arg2 : i32 memref.store %2, %arg0[%i] : memref scf.yield %1 : index } diff --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h --- a/mlir/include/mlir/Dialect/SCF/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/SCF.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SCF_H_ #define MLIR_DIALECT_SCF_H_ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -86,9 +87,9 @@ /// expect the body building functions to return their current value. /// The built nested scf::For are captured in `capturedLoops` when non-null. LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, - ValueRange ubs, ValueRange steps, - function_ref - bodyBuilder = nullptr); + ValueRange ubs, ValueRange steps, + function_ref + bodyBuilder = nullptr); } // end namespace scf } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -20,6 +20,7 @@ def SCF_Dialect : Dialect { let name = "scf"; let cppNamespace = "::mlir::scf"; + let dependentDialects = ["arith::ArithmeticDialect"]; } // Base class for SCF dialect ops. @@ -170,7 +171,7 @@ %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) { %t = load %buffer[%iv] : memref<1024xf32> - %sum_next = addf %sum_iter, %t : f32 + %sum_next = arith.addf %sum_iter, %t : f32 // Yield current iteration sum to next iteration %sum_iter or to %sum // if final iteration. 
scf.yield %sum_next : f32 @@ -194,9 +195,9 @@ %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) { %t = load %buffer[%iv] : memref<1024xf32> - %cond = cmpf "ugt", %t, %c0 : f32 + %cond = arith.cmpf "ugt", %t, %c0 : f32 %sum_next = scf.if %cond -> (f32) { - %new_sum = addf %sum_iter, %t : f32 + %new_sum = arith.addf %sum_iter, %t : f32 scf.yield %new_sum : f32 } else { scf.yield %sum_iter : f32 @@ -451,7 +452,7 @@ %elem_to_reduce = load %buffer[%iv] : memref<100xf32> scf.reduce(%elem_to_reduce) : f32 { ^bb0(%lhs : f32, %rhs: f32): - %res = addf %lhs, %rhs : f32 + %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 } } @@ -519,7 +520,7 @@ %operand = constant 1.0 : f32 scf.reduce(%operand) : f32 { ^bb0(%lhs : f32, %rhs: f32): - %res = addf %lhs, %rhs : f32 + %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 } ``` diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -14,6 +14,7 @@ #ifndef MLIR_SHAPE_IR_SHAPE_H #define MLIR_SHAPE_IR_SHAPE_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -35,7 +35,7 @@ }]; let cppNamespace = "::mlir::shape"; - let dependentDialects = ["tensor::TensorDialect"]; + let dependentDialects = ["arith::ArithmeticDialect", "tensor::TensorDialect"]; let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -43,8 +43,8 @@ ins(%arga, %argb: tensor, tensor) outs(%argx: tensor) { ^bb(%a: f64, %b: f64, %x: f64): - %0 = mulf %a, %b : f64 - %1 = addf %x, %0 : f64 + %0 = arith.mulf %a, %b : f64 + %1 = arith.addf %x, %0 : f64 linalg.yield %1 : f64 } -> tensor return %0 : tensor @@ -54,6 +54,7 @@ let constructor = "mlir::createSparsificationPass()"; let dependentDialects = [ "AffineDialect", + "arith::ArithmeticDialect", "LLVM::LLVMDialect", "memref::MemRefDialect", "scf::SCFDialect", @@ -103,6 +104,7 @@ }]; let constructor = "mlir::createSparseTensorConversionPass()"; let dependentDialects = [ + "arith::ArithmeticDialect", "LLVM::LLVMDialect", "memref::MemRefDialect", "scf::SCFDialect", diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -41,77 +42,15 @@ #include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc" namespace mlir { -/// This is a refinement of the "constant" op for the case where it is -/// returning a float value of FloatType. 
-/// -/// %1 = "std.constant"(){value: 42.0} : bf16 -/// -class ConstantFloatOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - - /// Builds a constant float op producing a float of the specified type. - static void build(OpBuilder &builder, OperationState &result, - const APFloat &value, FloatType type); - - APFloat getValue() { - return (*this)->getAttrOfType("value").getValue(); - } - - static bool classof(Operation *op); -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of IntegerType. -/// -/// %1 = "std.constant"(){value: 42} : i32 -/// -class ConstantIntOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - /// Build a constant int op producing an integer of the specified width. - static void build(OpBuilder &builder, OperationState &result, int64_t value, - unsigned width); - - /// Build a constant int op producing an integer with the specified type, - /// which must be an integer type. - static void build(OpBuilder &builder, OperationState &result, int64_t value, - Type type); - - int64_t getValue() { - return (*this)->getAttrOfType("value").getInt(); - } - - static bool classof(Operation *op); -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of Index type. -/// -/// %1 = "std.constant"(){value: 99} : () -> index -/// -class ConstantIndexOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - - /// Build a constant int op producing an index. - static void build(OpBuilder &builder, OperationState &result, int64_t value); - - int64_t getValue() { - return (*this)->getAttrOfType("value").getInt(); - } - - static bool classof(Operation *op); -}; /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer /// comparison predicates. -bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, +bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs); /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point /// comparison predicates. -bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, +bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs); /// Returns the identity value attribute associated with an AtomicRMWKind op. diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -25,6 +25,7 @@ def StandardOps_Dialect : Dialect { let name = "std"; let cppNamespace = "::mlir"; + let dependentDialects = ["arith::ArithmeticDialect"]; let hasConstantMaterializer = 1; } @@ -182,138 +183,6 @@ [DeclareOpInterfaceMethods])>, Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>; -//===----------------------------------------------------------------------===// -// AbsFOp -//===----------------------------------------------------------------------===// - -def AbsFOp : FloatUnaryOp<"absf"> { - let summary = "floating point absolute-value operation"; - let description = [{ - The `absf` operation computes the absolute value. It takes one operand and - returns one result of the same type. This type may be a float scalar type, - a vector whose element type is float, or a tensor of floats. - - Example: - - ```mlir - // Scalar absolute value. - %a = absf %b : f64 - - // SIMD vector element-wise absolute value. 
- %f = absf %g : vector<4xf32> - - // Tensor element-wise absolute value. - %x = absf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// AddFOp -//===----------------------------------------------------------------------===// - -def AddFOp : FloatBinaryOp<"addf"> { - let summary = "floating point addition operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.addf` ssa-use `,` ssa-use `:` type - ``` - - The `addf` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. - - Example: - - ```mlir - // Scalar addition. - %a = addf %b, %c : f64 - - // SIMD vector addition, e.g. for Intel SSE. - %f = addf %g, %h : vector<4xf32> - - // Tensor addition. - %x = addf %y, %z : tensor<4x?xbf16> - ``` - - TODO: In the distant future, this will accept optional attributes for fast - math, contraction, rounding mode, and other controls. - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// AddIOp -//===----------------------------------------------------------------------===// - -def AddIOp : IntBinaryOp<"addi", [Commutative]> { - let summary = "integer addition operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.addi` ssa-use `,` ssa-use `:` type - ``` - - The `addi` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be an integer scalar - type, a vector whose element type is integer, or a tensor of integers. It - has no standard attributes. - - Example: - - ```mlir - // Scalar addition. - %a = addi %b, %c : i64 - - // SIMD vector element-wise addition, e.g. for Intel SSE. - %f = addi %g, %h : vector<4xi32> - - // Tensor element-wise addition. - %x = addi %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// AndOp -//===----------------------------------------------------------------------===// - -def AndOp : IntBinaryOp<"and", [Commutative]> { - let summary = "integer binary and"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.and` ssa-use `,` ssa-use `:` type - ``` - - The `and` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise and. - %a = and %b, %c : i64 - - // SIMD vector element-wise bitwise integer and. - %f = and %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer and. 
- %x = and %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // AssertOp //===----------------------------------------------------------------------===// @@ -413,7 +282,7 @@ %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%current_value : f32): %c1 = constant 1.0 : f32 - %inc = addf %c1, %current_value : f32 + %inc = arith.addf %c1, %current_value : f32 atomic_yield %inc : f32 } ``` @@ -456,32 +325,6 @@ let assemblyFormat = "$result attr-dict `:` type($result)"; } -//===----------------------------------------------------------------------===// -// BitcastOp -//===----------------------------------------------------------------------===// - -def BitcastOp : ArithmeticCastOp<"bitcast"> { - let summary = "bitcast between values of equal bit width"; - let description = [{ - Bitcast an integer or floating point value to an integer or floating point - value of equal bit width. When operating on vectors, casts elementwise. - - Note that this implements a logical bitcast independent of target - endianness. This allows constant folding without target information and is - consitent with the bitcast constant folders in LLVM (see - https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168) - For targets where the source and target type have the same endianness (which - is the standard), this cast will also change no bits at runtime, but it may - still require an operation, for example if the machine has different - floating point and integer register files. For targets that have a different - endianness for the source and target types (e.g. float is big-endian and - integer is little-endian) a proper lowering would add operations to swap the - order of words in addition to the bitcast. - }]; - let hasFolder = 1; -} - - //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// @@ -666,240 +509,6 @@ let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)"; } -//===----------------------------------------------------------------------===// -// CeilFOp -//===----------------------------------------------------------------------===// - -def CeilFOp : FloatUnaryOp<"ceilf"> { - let summary = "ceiling of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.ceilf` ssa-use `:` type - ``` - - The `ceilf` operation computes the ceiling of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar ceiling value. - %a = ceilf %b : f64 - - // SIMD vector element-wise ceiling value. - %f = ceilf %g : vector<4xf32> - - // Tensor element-wise ceiling value. - %x = ceilf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// FloorFOp -//===----------------------------------------------------------------------===// - -def FloorFOp : FloatUnaryOp<"floorf"> { - let summary = "floor of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.floorf` ssa-use `:` type - ``` - - The `floorf` operation computes the floor of a given value. It takes one - operand and returns one result of the same type. 
This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar floor value. - %a = floorf %b : f64 - - // SIMD vector element-wise floor value. - %f = floorf %g : vector<4xf32> - - // Tensor element-wise floor value. - %x = floorf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, - DeclareOpInterfaceMethods, TypesMatchWith< - "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> { - let summary = "floating-point comparison operation"; - let description = [{ - The `cmpf` operation compares its two operands according to the float - comparison rules and the predicate specified by the respective attribute. - The predicate defines the type of comparison: (un)orderedness, (in)equality - and signed less/greater than (or equal to) as well as predicates that are - always true or false. The operands must have the same type, and this type - must be a float type, or a vector or tensor thereof. The result is an i1, - or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, - the operands are always treated as signed. The u prefix indicates - *unordered* comparison, not unsigned comparison, so "une" means unordered or - not equal. For the sake of readability by humans, custom assembly form for - the operation uses a string-typed attribute for the predicate. The value of - this attribute corresponds to lower-cased name of the predicate constant, - e.g., "one" means "ordered not equal". The string representation of the - attribute is merely a syntactic sugar and is converted to an integer - attribute by the parser. - - Example: - - ```mlir - %r1 = cmpf "oeq" %0, %1 : f32 - %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> - %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 - ``` - }]; - - let arguments = (ins - CmpFPredicateAttr:$predicate, - FloatLike:$lhs, - FloatLike:$rhs - ); - let results = (outs BoolLike:$result); - - let builders = [ - OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - ::buildCmpFOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); - - CmpFPredicate getPredicate() { - return (CmpFPredicate)(*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let hasFolder = 1; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - -//===----------------------------------------------------------------------===// -// CmpIOp -//===----------------------------------------------------------------------===// - -def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, - DeclareOpInterfaceMethods, TypesMatchWith< - "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> { - let summary = "integer comparison operation"; - let description = [{ - The `cmpi` operation is a generic comparison for integer-like types. Its two - arguments can be integers, vectors or tensors thereof as long as their types - match. 
The operation produces an i1 for the former case, a vector or a - tensor of i1 with the same shape as inputs in the other cases. - - Its first argument is an attribute that defines which type of comparison is - performed. The following comparisons are supported: - - - equal (mnemonic: `"eq"`; integer value: `0`) - - not equal (mnemonic: `"ne"`; integer value: `1`) - - signed less than (mnemonic: `"slt"`; integer value: `2`) - - signed less than or equal (mnemonic: `"sle"`; integer value: `3`) - - signed greater than (mnemonic: `"sgt"`; integer value: `4`) - - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`) - - unsigned less than (mnemonic: `"ult"`; integer value: `6`) - - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`) - - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`) - - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`) - - The result is `1` if the comparison is true and `0` otherwise. For vector or - tensor operands, the comparison is performed elementwise and the element of - the result indicates whether the comparison is true for the operand elements - with the same indices as those of the result. - - Note: while the custom assembly form uses strings, the actual underlying - attribute has integer type (or rather enum class in C++ code) as seen from - the generic assembly form. String literals are used to improve readability - of the IR by humans. - - This operation only applies to integer-like operands, but not floats. The - main reason being that comparison operations have diverging sets of - attributes: integers require sign specification while floats require various - floating point-related particularities, e.g., `-ffast-math` behavior, - IEEE754 compliance, etc - ([rationale](../Rationale/Rationale.md#splitting-floating-point-vs-integer-operations)). - The type of comparison is specified as attribute to avoid introducing ten - similar operations, taking into account that they are often implemented - using the same operation downstream - ([rationale](../Rationale/Rationale.md#specifying-comparison-kind-as-attribute)). The - separation between signed and unsigned order comparisons is necessary - because of integers being signless. The comparison operation must know how - to interpret values with the foremost bit being set: negatives in two's - complement or large positives - ([rationale](../Rationale/Rationale.md#specifying-sign-in-integer-comparison-operations)). - - Example: - - ```mlir - // Custom form of scalar "signed less than" comparison. - %x = cmpi "slt", %lhs, %rhs : i32 - - // Generic form of the same operation. - %x = "std.cmpi"(%lhs, %rhs) {predicate = 2 : i64} : (i32, i32) -> i1 - - // Custom form of vector equality comparison. - %x = cmpi "eq", %lhs, %rhs : vector<4xi64> - - // Generic form of the same operation. 
- %x = "std.cmpi"(%lhs, %rhs) {predicate = 0 : i64} - : (vector<4xi64>, vector<4xi64>) -> vector<4xi1> - ``` - }]; - - let arguments = (ins - CmpIPredicateAttr:$predicate, - SignlessIntegerLike:$lhs, - SignlessIntegerLike:$rhs - ); - let results = (outs BoolLike:$result); - - let builders = [ - OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - ::buildCmpIOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); - - CmpIPredicate getPredicate() { - return (CmpIPredicate)(*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let hasFolder = 1; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// @@ -1095,264 +704,111 @@ } //===----------------------------------------------------------------------===// -// CopySignOp +// MaxFOp //===----------------------------------------------------------------------===// -def CopySignOp : FloatBinaryOp<"copysign"> { - let summary = "A copysign operation"; +def MaxFOp : FloatBinaryOp<"maxf"> { + let summary = "floating-point maximum operation"; let description = [{ Syntax: ``` - operation ::= ssa-id `=` `std.copysign` ssa-use `,` ssa-use `:` type + operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type ``` - The `copysign` returns a value with the magnitude of the first operand and - the sign of the second operand. It takes two operands and returns one - result of the same type. This type may be a float scalar type, a vector - whose element type is float, or a tensor of floats. It has no standard - attributes. + Returns the maximum of the two arguments, treating -0.0 as less than +0.0. + If one of the arguments is NaN, then the result is also NaN. Example: ```mlir - // Scalar copysign value. - %a = copysign %b, %c : f64 - - // SIMD vector element-wise copysign value. - %f = copysign %g, %h : vector<4xf32> - - // Tensor element-wise copysign value. - %x = copysign %y, %z : tensor<4x?xf8> + // Scalar floating-point maximum. + %a = maxf %b, %c : f64 ``` }]; } //===----------------------------------------------------------------------===// -// DivFOp -//===----------------------------------------------------------------------===// - -def DivFOp : FloatBinaryOp<"divf"> { - let summary = "floating point division operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// FmaFOp +// MaxSIOp //===----------------------------------------------------------------------===// -def FmaFOp : FloatTernaryOp<"fmaf"> { - let summary = "floating point fused multipy-add operation"; +def MaxSIOp : IntBinaryOp<"maxsi"> { + let summary = "signed integer maximum operation"; let description = [{ Syntax: ``` - operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type + operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type ``` - The `fmaf` operation takes three operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. 
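For the `cmpi`/`cmpf` definitions removed above, the arith forms already used by other hunks in this patch look like the following (both lines are copied from the scf examples elsewhere in this diff; the predicate encodings themselves are unchanged):

```mlir
// arith spellings as they appear in the scf hunks of this patch.
%1 = arith.cmpi slt, %i, %arg1 : index
%cond = arith.cmpf "ugt", %t, %c0 : f32
```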
+ Returns the larger of %a and %b comparing the values as signed integers. Example: ```mlir - // Scalar fused multiply-add: d = a*b + c - %d = fmaf %a, %b, %c : f64 - - // SIMD vector fused multiply-add, e.g. for Intel SSE. - %i = fmaf %f, %g, %h : vector<4xf32> - - // Tensor fused multiply-add. - %w = fmaf %x, %y, %z : tensor<4x?xbf16> + // Scalar signed integer maximum. + %a = maxsi %b, %c : i64 ``` - - The semantics of the operation correspond to those of the `llvm.fma` - [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the - particular case of lowering to LLVM, this is guaranteed to lower - to the `llvm.fma.*` intrinsic. }]; } //===----------------------------------------------------------------------===// -// FPExtOp +// MaxUIOp //===----------------------------------------------------------------------===// -def FPExtOp : ArithmeticCastOp<"fpext"> { - let summary = "cast from floating-point to wider floating-point"; +def MaxUIOp : IntBinaryOp<"maxui"> { + let summary = "unsigned integer maximum operation"; let description = [{ - Cast a floating-point value to a larger floating-point-typed value. - The destination type must to be strictly wider than the source type. - When operating on vectors, casts elementwise. - }]; -} + Syntax: -//===----------------------------------------------------------------------===// -// FPToSIOp -//===----------------------------------------------------------------------===// + ``` + operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type + ``` -def FPToSIOp : ArithmeticCastOp<"fptosi"> { - let summary = "cast from floating-point type to integer type"; - let description = [{ - Cast from a value interpreted as floating-point to the nearest (rounding - towards zero) signed integer value. When operating on vectors, casts - elementwise. - }]; -} + Returns the larger of %a and %b comparing the values as unsigned integers. -//===----------------------------------------------------------------------===// -// FPToUIOp -//===----------------------------------------------------------------------===// + Example: -def FPToUIOp : ArithmeticCastOp<"fptoui"> { - let summary = "cast from floating-point type to integer type"; - let description = [{ - Cast from a value interpreted as floating-point to the nearest (rounding - towards zero) unsigned integer value. When operating on vectors, casts - elementwise. + ```mlir + // Scalar unsigned integer maximum. + %a = maxui %b, %c : i64 + ``` }]; } //===----------------------------------------------------------------------===// -// FPTruncOp +// MinFOp //===----------------------------------------------------------------------===// -def FPTruncOp : ArithmeticCastOp<"fptrunc"> { - let summary = "cast from floating-point to narrower floating-point"; +def MinFOp : FloatBinaryOp<"minf"> { + let summary = "floating-point minimum operation"; let description = [{ - Truncate a floating-point value to a smaller floating-point-typed value. - The destination type must be strictly narrower than the source type. - If the value cannot be exactly represented, it is rounded using the default - rounding mode. When operating on vectors, casts elementwise. - }]; + Syntax: - let hasFolder = 1; + ``` + operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type + ``` + + Returns the minimum of the two arguments, treating -0.0 as less than +0.0. + If one of the arguments is NaN, then the result is also NaN. + + Example: + + ```mlir + // Scalar floating-point minimum. 
+ %a = minf %b, %c : f64 + ``` + }]; } //===----------------------------------------------------------------------===// -// IndexCastOp +// MinSIOp //===----------------------------------------------------------------------===// -def IndexCastOp : ArithmeticCastOp<"index_cast"> { - let summary = "cast between index and integer types"; +def MinSIOp : IntBinaryOp<"minsi"> { + let summary = "signed integer minimum operation"; let description = [{ - Casts between scalar or vector integers and corresponding 'index' scalar or - vectors. Index is an integer of platform-specific bit width. If casting to - a wider integer, the value is sign-extended. If casting to a narrower - integer, the value is truncated. - }]; - - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// MaxFOp -//===----------------------------------------------------------------------===// - -def MaxFOp : FloatBinaryOp<"maxf"> { - let summary = "floating-point maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type - ``` - - Returns the maximum of the two arguments, treating -0.0 as less than +0.0. - If one of the arguments is NaN, then the result is also NaN. - - Example: - - ```mlir - // Scalar floating-point maximum. - %a = maxf %b, %c : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MaxSIOp -//===----------------------------------------------------------------------===// - -def MaxSIOp : IntBinaryOp<"maxsi"> { - let summary = "signed integer maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type - ``` - - Returns the larger of %a and %b comparing the values as signed integers. - - Example: - - ```mlir - // Scalar signed integer maximum. - %a = maxsi %b, %c : i64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MaxUIOp -//===----------------------------------------------------------------------===// - -def MaxUIOp : IntBinaryOp<"maxui"> { - let summary = "unsigned integer maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type - ``` - - Returns the larger of %a and %b comparing the values as unsigned integers. - - Example: - - ```mlir - // Scalar unsigned integer maximum. - %a = maxui %b, %c : i64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MinFOp -//===----------------------------------------------------------------------===// - -def MinFOp : FloatBinaryOp<"minf"> { - let summary = "floating-point minimum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type - ``` - - Returns the minimum of the two arguments, treating -0.0 as less than +0.0. - If one of the arguments is NaN, then the result is also NaN. - - Example: - - ```mlir - // Scalar floating-point minimum. 
- %a = minf %b, %c : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MinSIOp -//===----------------------------------------------------------------------===// - -def MinSIOp : IntBinaryOp<"minsi"> { - let summary = "signed integer minimum operation"; - let description = [{ - Syntax: + Syntax: ``` operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type @@ -1393,119 +849,6 @@ }]; } -//===----------------------------------------------------------------------===// -// MulFOp -//===----------------------------------------------------------------------===// - -def MulFOp : FloatBinaryOp<"mulf"> { - let summary = "floating point multiplication operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.mulf` ssa-use `,` ssa-use `:` type - ``` - - The `mulf` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. - - Example: - - ```mlir - // Scalar multiplication. - %a = mulf %b, %c : f64 - - // SIMD pointwise vector multiplication, e.g. for Intel SSE. - %f = mulf %g, %h : vector<4xf32> - - // Tensor pointwise multiplication. - %x = mulf %y, %z : tensor<4x?xbf16> - ``` - - TODO: In the distant future, this will accept optional attributes for fast - math, contraction, rounding mode, and other controls. - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// MulIOp -//===----------------------------------------------------------------------===// - -def MulIOp : IntBinaryOp<"muli", [Commutative]> { - let summary = "integer multiplication operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// NegFOp -//===----------------------------------------------------------------------===// - -def NegFOp : FloatUnaryOp<"negf"> { - let summary = "floating point negation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `negf` ssa-use `:` type - ``` - - The `negf` operation computes the negation of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar negation value. - %a = negf %b : f64 - - // SIMD vector element-wise negation value. - %f = negf %g : vector<4xf32> - - // Tensor element-wise negation value. - %x = negf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// OrOp -//===----------------------------------------------------------------------===// - -def OrOp : IntBinaryOp<"or", [Commutative]> { - let summary = "integer binary or"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `or` ssa-use `,` ssa-use `:` type - ``` - - The `or` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise or. - %a = or %b, %c : i64 - - // SIMD vector element-wise bitwise integer or. - %f = or %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer or. 
- %x = or %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// @@ -1538,14 +881,6 @@ let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)"; } -//===----------------------------------------------------------------------===// -// RemFOp -//===----------------------------------------------------------------------===// - -def RemFOp : FloatBinaryOp<"remf"> { - let summary = "floating point division remainder operation"; -} - //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// @@ -1641,236 +976,6 @@ let hasFolder = 1; } -//===----------------------------------------------------------------------===// -// ShiftLeftOp -//===----------------------------------------------------------------------===// - -def ShiftLeftOp : IntBinaryOp<"shift_left"> { - let summary = "integer left-shift"; - let description = [{ - The shift_left operation shifts an integer value to the left by a variable - amount. The low order bits are filled with zeros. - - Example: - - ```mlir - %1 = constant 5 : i8 // %1 is 0b00000101 - %2 = constant 3 : i8 - %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// SignedDivIOp -//===----------------------------------------------------------------------===// - -def SignedDivIOp : IntBinaryOp<"divi_signed"> { - let summary = "signed integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `divi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards zero. Treats the leading bit as - sign, i.e. `6 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. - %a = divi_signed %b, %c : i64 - - // SIMD vector element-wise division. - %f = divi_signed %g, %h : vector<4xi32> - - // Tensor element-wise integer division. - %x = divi_signed %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedFloorDivIOp -//===----------------------------------------------------------------------===// - -def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> { - let summary = "signed floor integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `floordivi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. 
- %a = floordivi_signed %b, %c : i64 - - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedCeilDivIOp -//===----------------------------------------------------------------------===// - -def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> { - let summary = "signed ceil integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `ceildivi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. - %a = ceildivi_signed %b, %c : i64 - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedRemIOp -//===----------------------------------------------------------------------===// - -def SignedRemIOp : IntBinaryOp<"remi_signed"> { - let summary = "signed integer division remainder operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.remi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division remainder. Treats the leading bit as sign, i.e. `6 % - -2 = 0`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar signed integer division remainder. - %a = remi_signed %b, %c : i64 - - // SIMD vector element-wise division remainder. - %f = remi_signed %g, %h : vector<4xi32> - - // Tensor element-wise integer division remainder. - %x = remi_signed %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedShiftRightOp -//===----------------------------------------------------------------------===// - -def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> { - let summary = "signed integer right-shift"; - let description = [{ - The shift_right_signed operation shifts an integer value to the right by - a variable amount. The integer is interpreted as signed. The high order - bits in the output are filled with copies of the most-significant bit - of the shifted value (which means that the sign of the value is preserved). - - Example: - - ```mlir - %1 = constant 160 : i8 // %1 is 0b10100000 - %2 = constant 3 : i8 - %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 - %4 = constant 96 : i8 // %4 is 0b01100000 - %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// SignExtendIOp -//===----------------------------------------------------------------------===// - -def SignExtendIOp : Std_Op<"sexti", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer sign extension operation"; - let description = [{ - The integer sign extension operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be larger than the input bit-width (N > M). - The top-most (N - M) bits of the output are filled with copies - of the most-significant bit of the input. 
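The signed division, remainder, and shift ops removed in this stretch keep their semantics but pick up signedness-suffixed names in arith; a hedged sketch, with spellings assumed rather than shown in this hunk (only `ceildivsi` is hinted at by the StdExpand comment later in the patch):

```mlir
// Assumed arith spellings for the removed signed std ops.
%q = arith.floordivsi %b, %c : i64   // was floordivi_signed
%p = arith.ceildivsi %b, %c : i64    // was ceildivi_signed
%r = arith.remsi %b, %c : i64        // was remi_signed
%s = arith.shrsi %x, %y : i8         // was shift_right_signed
```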
- - Example: - - ```mlir - %1 = constant 5 : i3 // %1 is 0b101 - %2 = sexti %1 : i3 to i6 // %2 is 0b111101 - %3 = constant 2 : i3 // %3 is 0b010 - %4 = sexti %3 : i3 to i6 // %4 is 0b000010 - - %5 = sexti %0 : vector<2 x i32> to vector<2 x i64> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SIToFPOp -//===----------------------------------------------------------------------===// - -def SIToFPOp : ArithmeticCastOp<"sitofp"> { - let summary = "cast from integer type to floating-point"; - let description = [{ - Cast from a value interpreted as a signed integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. When operating on vectors, casts - elementwise. - }]; -} - //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// @@ -1918,25 +1023,6 @@ let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } -//===----------------------------------------------------------------------===// -// SubFOp -//===----------------------------------------------------------------------===// - -def SubFOp : FloatBinaryOp<"subf"> { - let summary = "floating point subtraction operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SubIOp -//===----------------------------------------------------------------------===// - -def SubIOp : IntBinaryOp<"subi"> { - let summary = "integer subtraction operation"; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// @@ -2025,225 +1111,4 @@ let hasCanonicalizer = 1; } -//===----------------------------------------------------------------------===// -// TruncateIOp -//===----------------------------------------------------------------------===// - -def TruncateIOp : Std_Op<"trunci", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer truncation operation"; - let description = [{ - The integer truncation operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be smaller than the input bit-width (N < M). - The top-most (N - M) bits of the input are discarded. 
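The sign-extension, zero-extension, truncation, and int/float cast ops dropped from std in this region likewise have arith counterparts; a short sketch with assumed spellings (these forms are not shown verbatim in this diff):

```mlir
// Assumed arith spellings for the removed std casts.
%se = arith.extsi %a : i3 to i6      // was sexti
%ze = arith.extui %b : i3 to i6      // was zexti
%tr = arith.trunci %c : i5 to i3     // was trunci
%fp = arith.sitofp %d : i32 to f32   // was sitofp
```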
- - Example: - - ```mlir - %1 = constant 21 : i5 // %1 is 0b10101 - %2 = trunci %1 : i5 to i4 // %2 is 0b0101 - %3 = trunci %1 : i5 to i3 // %3 is 0b101 - - %5 = trunci %0 : vector<2 x i32> to vector<2 x i16> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; - - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UIToFPOp -//===----------------------------------------------------------------------===// - -def UIToFPOp : ArithmeticCastOp<"uitofp"> { - let summary = "cast from unsigned integer type to floating-point"; - let description = [{ - Cast from a value interpreted as unsigned integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. When operating on vectors, casts - elementwise. - }]; -} - -//===----------------------------------------------------------------------===// -// UnsignedDivIOp -//===----------------------------------------------------------------------===// - -def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> { - let summary = "unsigned integer division operation"; - let description = [{ - Syntax: - ``` - operation ::= ssa-id `=` `std.divi_unsigned` ssa-use `,` ssa-use `:` type - ``` - - Unsigned integer division. Rounds towards zero. Treats the leading bit as - the most significant, i.e. for `i16` given two's complement representation, - `6 / -2 = 6 / (2^16 - 2) = 0`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar unsigned integer division. - %a = divi_unsigned %b, %c : i64 - - // SIMD vector element-wise division. - %f = divi_unsigned %g, %h : vector<4xi32> - - // Tensor element-wise integer division. - %x = divi_unsigned %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UnsignedRemIOp -//===----------------------------------------------------------------------===// - -def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> { - let summary = "unsigned integer division remainder operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.remi_unsigned` ssa-use `,` ssa-use `:` type - ``` - - Unsigned integer division remainder. Treats the leading bit as the most - significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar unsigned integer division remainder. - %a = remi_unsigned %b, %c : i64 - - // SIMD vector element-wise division remainder. - %f = remi_unsigned %g, %h : vector<4xi32> - - // Tensor element-wise integer division remainder. 
- %x = remi_unsigned %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UnsignedShiftRightOp -//===----------------------------------------------------------------------===// - -def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> { - let summary = "unsigned integer right-shift"; - let description = [{ - The shift_right_unsigned operation shifts an integer value to the right by - a variable amount. The integer is interpreted as unsigned. The high order - bits are always filled with zeros. - - Example: - - ```mlir - %1 = constant 160 : i8 // %1 is 0b10100000 - %2 = constant 3 : i8 - %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// XOrOp -//===----------------------------------------------------------------------===// - -def XOrOp : IntBinaryOp<"xor", [Commutative]> { - let summary = "integer binary xor"; - let description = [{ - The `xor` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise xor. - %a = xor %b, %c : i64 - - // SIMD vector element-wise bitwise integer xor. - %f = xor %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer xor. - %x = xor %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// ZeroExtendIOp -//===----------------------------------------------------------------------===// - -def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer zero extension operation"; - let description = [{ - The integer zero extension operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be larger than the input bit-width (N > M). - The top-most (N - M) bits of the output are filled with zeros. - - Example: - - ```mlir - %1 = constant 5 : i3 // %1 is 0b101 - %2 = zexti %1 : i3 to i6 // %2 is 0b000101 - %3 = constant 2 : i3 // %3 is 0b010 - %4 = zexti %3 : i3 to i6 // %4 is 0b000010 - - %5 = zexti %0 : vector<2 x i32> to vector<2 x i64> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; -} - #endif // STANDARD_OPS diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td @@ -36,50 +36,4 @@ let cppNamespace = "::mlir"; } -// The predicate indicates the type of the comparison to perform: -// (un)orderedness, (in)equality and less/greater than (or equal to) as -// well as predicates that are always true or false. 
-def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">; -def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">; -def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">; -def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">; -def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">; -def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">; -def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">; -def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">; -def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">; -def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">; -def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">; -def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">; -def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">; -def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">; -def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">; -def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">; - -def CmpFPredicateAttr : I64EnumAttr< - "CmpFPredicate", "", - [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE, - CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT, - CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> { - let cppNamespace = "::mlir"; -} - -def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; -def CMPI_P_NE : I64EnumAttrCase<"ne", 1>; -def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>; -def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>; -def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>; -def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>; -def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>; -def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>; -def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>; -def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>; - -def CmpIPredicateAttr : I64EnumAttr< - "CmpIPredicate", "", - [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT, - CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> { - let cppNamespace = "::mlir"; -} - #endif // STANDARD_OPS_BASE diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -43,7 +43,7 @@ /// Creates an instance of the StdExpand pass that legalizes Std /// dialect ops to be convertible to LLVM. For example, -/// `std.ceildivi_signed` gets transformed to a number of std operations, +/// `arith.ceildivsi` gets transformed to a number of std operations, /// which can be lowered to LLVM; `memref.reshape` gets converted to /// `memref_reinterpret_cast`. std::unique_ptr<Pass> createStdExpandOpsPass(); diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h --- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h @@ -16,6 +16,7 @@ #ifndef MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H #define MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -24,7 +25,7 @@ namespace mlir { /// Matches a ConstantIndexOp. -detail::op_matcher<ConstantIndexOp> matchConstantIndex(); +detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex(); /// Detects the `values` produced by a ConstantIndexOp and places the new /// constant in place of the corresponding sentinel value.
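As a concrete reference for the matcher change just above, the op that `matchConstantIndex()` now recognizes is the arith form of an index constant; a minimal sketch (the `arith.constant` spelling is taken from other hunks of this patch, the specific values are illustrative):

```mlir
// Index constants now produced and matched via the arith dialect.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
```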
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_ #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td @@ -46,6 +46,7 @@ }]; let hasConstantMaterializer = 1; + let dependentDialects = ["arith::ArithmeticDialect"]; } #endif // TENSOR_BASE diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -21,6 +21,7 @@ let name = "vector"; let cppNamespace = "::mlir::vector"; let hasConstantMaterializer = 1; + let dependentDialects = ["arith::ArithmeticDialect"]; } // Base class for Vector dialect ops. @@ -579,7 +580,7 @@ %idx0 = ... : index // dynamic computation producing the value 1 of index type %idx1 = ... : index - %0 = constant dense<0, 1, 2, 3>: vector<4xi32> + %0 = arith.constant dense<0, 1, 2, 3>: vector<4xi32> // extracts values [0, 1] %1 = vector.extract_map %0[%idx0] : vector<4xi32> to vector<2xi32> // extracts values [1, 2] @@ -767,7 +768,7 @@ %idx0 = ... : index // dynamic computation producing the value 1 of index type %idx1 = ... : index / - %0 = constant dense<0, 1, 2, 3>: vector<4xi32> + %0 = arith.constant dense<0, 1, 2, 3>: vector<4xi32> // extracts values [0, 1] %1 = vector.extract_map %0[%idx0] : vector<4xi32> to vector<2xi32> // extracts values [1, 2] diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -173,9 +173,9 @@ /// canonicalizations pattern to propagate and fold the vector /// insert_map/extract_map operations. 
/// Transforms: -// %v = addf %a, %b : vector<32xf32> +// %v = arith.addf %a, %b : vector<32xf32> /// to: -/// %v = addf %a, %b : vector<32xf32> +/// %v = arith.addf %a, %b : vector<32xf32> /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32> /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32> Optional diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -325,7 +325,7 @@ %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32> %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32> - %d = addf %1, %2 : f32 + %d = arith.addf %1, %2 : f32 ``` }]; let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a, diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1383,11 +1383,11 @@ /// /// Examples: /// ``` -/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32 +/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32 /// ``` /// can be tensorized to /// ``` -/// %tensor = "std.addf"(%a, %b) : (tensor, tensor) +/// %tensor = "arith.addf"(%a, %b) : (tensor, tensor) /// -> tensor /// ``` /// diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -52,6 +53,7 @@ // clang-format off registry.insert - %3 = addf %2, %2 : f32 + %3 = arith.addf %2, %2 : f32 affine.store %3, %arg0[%arg2] : memref<10xf32> } affine.for %arg2 = 0 to 10 { %2 = affine.load %1[%arg2] : memref<10xf32> - %3 = mulf %2, %2 : f32 + %3 = arith.mulf %2, %2 : f32 affine.store %3, %arg1[%arg2] : memref<10xf32> } return @@ -67,10 +67,10 @@ affine.store %cst, %0[0] : memref<1xf32> affine.store %cst, %1[0] : memref<1xf32> %2 = affine.load %1[0] : memref<1xf32> - %3 = mulf %2, %2 : f32 + %3 = arith.mulf %2, %2 : f32 affine.store %3, %arg1[%arg2] : memref<10xf32> %4 = affine.load %0[0] : memref<1xf32> - %5 = addf %4, %4 : f32 + %5 = arith.addf %4, %4 : f32 affine.store %5, %arg0[%arg2] : memref<10xf32> } return @@ -87,7 +87,7 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> - %2 = mulf %0, %1 : f32 + %2 = arith.mulf %0, %1 : f32 affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> } } @@ -95,7 +95,7 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> - %2 = addf %0, %1 : f32 + %2 = arith.addf %0, %1 : f32 affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32> } } @@ -111,11 +111,11 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> - %2 = mulf %0, %1 : f32 + %2 = arith.mulf %0, %1 : f32 affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> - %5 = addf %3, 
%4 : f32 + %5 = arith.addf %3, %4 : f32 affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32> } } @@ -481,6 +481,7 @@ let summary = "Coalesce nested loops with independent bounds into a single " "loop"; let constructor = "mlir::createLoopCoalescingPass()"; + let dependentDialects = ["arith::ArithmeticDialect"]; } def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> { @@ -524,7 +525,7 @@ %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) { affine.for %arg3 = 0 to 16 { %a = affine.load %A[%arg3] : memref<16xf64, #tile> - %p = mulf %a, %a : f64 + %p = arith.mulf %a, %a : f64 affine.store %p, %A[%arg3] : memref<16xf64, #tile> } %c = alloc() : memref<16xf64, #tile> @@ -540,7 +541,7 @@ -> memref<4x4xf64> { affine.for %arg3 = 0 to 16 { %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> - %4 = mulf %3, %3 : f64 + %4 = arith.mulf %3, %3 : f64 affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> } %0 = alloc() : memref<4x4xf64> @@ -566,8 +567,8 @@ %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8> %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8> %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> - %3 = muli %0, %1 : i32 - %4 = addi %2, %3 : i32 + %3 = arith.muli %0, %1 : i32 + %4 = arith.addi %2, %3 : i32 affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> } } @@ -590,8 +591,8 @@ %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32> %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32> %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32> - %3 = muli %0, %1 : i32 - %4 = addi %2, %3 : i32 + %3 = arith.muli %0, %1 : i32 + %4 = arith.addi %2, %3 : i32 affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32> } } diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BuiltinOps.h" @@ -53,10 +54,10 @@ Operation *combinerOp = combinerOps.back(); Optional maybeKind = TypeSwitch>(combinerOp) - .Case([](Operation *) { return AtomicRMWKind::addf; }) - .Case([](Operation *) { return AtomicRMWKind::mulf; }) - .Case([](Operation *) { return AtomicRMWKind::addi; }) - .Case([](Operation *) { return AtomicRMWKind::muli; }) + .Case([](arith::AddFOp) { return AtomicRMWKind::addf; }) + .Case([](arith::MulFOp) { return AtomicRMWKind::mulf; }) + .Case([](arith::AddIOp) { return AtomicRMWKind::addi; }) + .Case([](arith::MulIOp) { return AtomicRMWKind::muli; }) .Default([](Operation *) -> Optional { // TODO: AtomicRMW supports other kinds of reductions this is // currently not detecting, add those when the need arises. @@ -640,10 +641,9 @@ auto symbol = operands[i]; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. 
- if (auto cOp = symbol.getDefiningOp()) + if (auto cOp = symbol.getDefiningOp()) dependenceDomain->addBound(FlatAffineConstraints::EQ, - valuePosMap.getSymPos(symbol), - cOp.getValue()); + valuePosMap.getSymPos(symbol), cOp.value()); } }; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IntegerSet.h" @@ -654,8 +655,8 @@ // Add top level symbol. appendSymbolId(val); // Check if the symbol is a constant. - if (auto constOp = val.getDefiningOp()) - addBound(BoundType::EQ, val, constOp.getValue()); + if (auto constOp = val.getDefiningOp()) + addBound(BoundType::EQ, val, constOp.value()); } LogicalResult diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -37,13 +37,12 @@ mlir-headers LINK_LIBS PUBLIC - MLIRAffine MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRInferTypeOpInterface - MLIRLinalg - MLIRSCF + MLIRSideEffectInterfaces + MLIRViewLikeInterface ) add_mlir_library(MLIRLoopAnalysis diff --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp --- a/mlir/lib/Analysis/NumberOfExecutions.cpp +++ b/mlir/lib/Analysis/NumberOfExecutions.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/NumberOfExecutions.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -11,9 +11,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/PresburgerSet.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/IntegerSet.h" #include "llvm/ADT/SmallPtrSet.h" @@ -98,8 +99,8 @@ assert(cst->containsId(value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. - if (auto cOp = value.getDefiningOp()) - cst->addBound(FlatAffineConstraints::EQ, value, cOp.getValue()); + if (auto cOp = value.getDefiningOp()) + cst->addBound(FlatAffineConstraints::EQ, value, cOp.value()); } else if (auto loop = getForInductionVarOwner(value)) { if (failed(cst->addAffineForOpDomain(loop))) return failure(); @@ -517,8 +518,8 @@ assert(isValidSymbol(symbol)); // Check if the symbol is a constant. 
if (auto *op = symbol.getDefiningOp()) { - if (auto constOp = dyn_cast(op)) { - cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.getValue()); + if (auto constOp = dyn_cast(op)) { + cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.value()); } } } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -15,6 +15,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -56,11 +57,11 @@ } Value visitAddExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr(expr); + return buildBinaryExpr(expr); } Value visitMulExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr(expr); + return buildBinaryExpr(expr); } /// Euclidean modulo operation: negative RHS is not allowed. @@ -89,11 +90,12 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value remainder = builder.create(loc, lhs, rhs); - Value zeroCst = builder.create(loc, 0); - Value isRemainderNegative = - builder.create(loc, CmpIPredicate::slt, remainder, zeroCst); - Value correctedRemainder = builder.create(loc, remainder, rhs); + Value remainder = builder.create(loc, lhs, rhs); + Value zeroCst = builder.create(loc, 0); + Value isRemainderNegative = builder.create( + loc, arith::CmpIPredicate::slt, remainder, zeroCst); + Value correctedRemainder = + builder.create(loc, remainder, rhs); Value result = builder.create(loc, isRemainderNegative, correctedRemainder, remainder); return result; @@ -126,15 +128,16 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value noneCst = builder.create(loc, -1); - Value negative = - builder.create(loc, CmpIPredicate::slt, lhs, zeroCst); - Value negatedDecremented = builder.create(loc, noneCst, lhs); + Value zeroCst = builder.create(loc, 0); + Value noneCst = builder.create(loc, -1); + Value negative = builder.create( + loc, arith::CmpIPredicate::slt, lhs, zeroCst); + Value negatedDecremented = builder.create(loc, noneCst, lhs); Value dividend = builder.create(loc, negative, negatedDecremented, lhs); - Value quotient = builder.create(loc, dividend, rhs); - Value correctedQuotient = builder.create(loc, noneCst, quotient); + Value quotient = builder.create(loc, dividend, rhs); + Value correctedQuotient = + builder.create(loc, noneCst, quotient); Value result = builder.create(loc, negative, correctedQuotient, quotient); return result; @@ -165,27 +168,26 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value oneCst = builder.create(loc, 1); - Value nonPositive = - builder.create(loc, CmpIPredicate::sle, lhs, zeroCst); - Value negated = builder.create(loc, zeroCst, lhs); - Value decremented = builder.create(loc, lhs, oneCst); + Value zeroCst = builder.create(loc, 0); + Value oneCst = builder.create(loc, 1); + Value nonPositive = builder.create( + loc, arith::CmpIPredicate::sle, lhs, zeroCst); + Value negated = builder.create(loc, zeroCst, lhs); + Value decremented = builder.create(loc, lhs, oneCst); Value dividend = builder.create(loc, nonPositive, negated, 
decremented); - Value quotient = builder.create(loc, dividend, rhs); - Value negatedQuotient = builder.create(loc, zeroCst, quotient); - Value incrementedQuotient = builder.create(loc, quotient, oneCst); + Value quotient = builder.create(loc, dividend, rhs); + Value negatedQuotient = + builder.create(loc, zeroCst, quotient); + Value incrementedQuotient = + builder.create(loc, quotient, oneCst); Value result = builder.create(loc, nonPositive, negatedQuotient, incrementedQuotient); return result; } Value visitConstantExpr(AffineConstantExpr expr) { - auto valueAttr = - builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); - auto op = - builder.create(loc, builder.getIndexType(), valueAttr); + auto op = builder.create(loc, expr.getValue()); return op.getResult(); } @@ -242,20 +244,21 @@ /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the /// `cmpi` operation followed by the `select` operation: /// -/// %cond = cmpi "predicate" %v0, %v1 +/// %cond = arith.cmpi "predicate" %v0, %v1 /// %result = select %cond, %v0, %v1 /// /// Multiple values are scanned in a linear sequence. This creates a data /// dependences that wouldn't exist in a tree reduction, but is easier to /// recognize as a reduction by the subsequent passes. -static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, +static Value buildMinMaxReductionSeq(Location loc, + arith::CmpIPredicate predicate, ValueRange values, OpBuilder &builder) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { - auto cmpOp = builder.create(loc, predicate, value, *valueIt); + auto cmpOp = builder.create(loc, predicate, value, *valueIt); value = builder.create(loc, cmpOp.getResult(), value, *valueIt); } @@ -267,7 +270,8 @@ static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands) { if (auto values = expandAffineMap(builder, loc, map, operands)) - return buildMinMaxReductionSeq(loc, CmpIPredicate::sgt, *values, builder); + return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values, + builder); return nullptr; } @@ -276,7 +280,8 @@ static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands) { if (auto values = expandAffineMap(builder, loc, map, operands)) - return buildMinMaxReductionSeq(loc, CmpIPredicate::slt, *values, builder); + return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values, + builder); return nullptr; } @@ -356,7 +361,7 @@ Location loc = op.getLoc(); Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); - Value step = rewriter.create(loc, op.getStep()); + Value step = rewriter.create(loc, op.getStep()); auto scfForOp = rewriter.create(loc, lowerBound, upperBound, step, op.getIterOperands()); rewriter.eraseBlock(scfForOp.getBody()); @@ -399,7 +404,7 @@ } steps.reserve(op.steps().size()); for (Attribute step : op.steps()) - steps.push_back(rewriter.create( + steps.push_back(rewriter.create( loc, step.cast().getInt())); // Get the terminator op. @@ -475,7 +480,7 @@ // Now we just have to handle the condition logic. 
auto integerSet = op.getIntegerSet(); - Value zeroConstant = rewriter.create(loc, 0); + Value zeroConstant = rewriter.create(loc, 0); SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); @@ -492,14 +497,17 @@ operandsRef.drop_front(numDims)); if (!affResult) return failure(); - auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; + auto pred = + isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge; Value cmpVal = - rewriter.create(loc, pred, affResult, zeroConstant); - cond = - cond ? rewriter.create(loc, cond, cmpVal).getResult() : cmpVal; + rewriter.create(loc, pred, affResult, zeroConstant); + cond = cond + ? rewriter.create(loc, cond, cmpVal).getResult() + : cmpVal; } cond = cond ? cond - : rewriter.create(loc, /*value=*/1, /*width=*/1); + : rewriter.create(loc, /*value=*/1, + /*width=*/1); bool hasElseRegion = !op.elseRegion().empty(); auto ifOp = rewriter.create(loc, op.getResultTypes(), cond, @@ -750,8 +758,9 @@ populateAffineToStdConversionPatterns(patterns); populateAffineToVectorConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); + target + .addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRAffine + MLIRArithmetic MLIRMemRef MLIRSCF MLIRPass diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp @@ -0,0 +1,304 @@ +//===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "../PassDetail.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// Straightforward Op Lowerings +//===----------------------------------------------------------------------===// + +using AddIOpLowering = VectorConvertToLLVMPattern; +using SubIOpLowering = VectorConvertToLLVMPattern; +using MulIOpLowering = VectorConvertToLLVMPattern; +using DivUIOpLowering = + VectorConvertToLLVMPattern; +using DivSIOpLowering = + VectorConvertToLLVMPattern; +using RemUIOpLowering = + VectorConvertToLLVMPattern; +using RemSIOpLowering = + VectorConvertToLLVMPattern; +using AndIOpLowering = VectorConvertToLLVMPattern; +using OrIOpLowering = VectorConvertToLLVMPattern; +using XOrIOpLowering = VectorConvertToLLVMPattern; +using ShLIOpLowering = VectorConvertToLLVMPattern; +using ShRUIOpLowering = + VectorConvertToLLVMPattern; +using ShRSIOpLowering = + VectorConvertToLLVMPattern; +using NegFOpLowering = VectorConvertToLLVMPattern; +using AddFOpLowering = VectorConvertToLLVMPattern; +using SubFOpLowering = VectorConvertToLLVMPattern; +using MulFOpLowering = VectorConvertToLLVMPattern; +using DivFOpLowering = VectorConvertToLLVMPattern; +using RemFOpLowering = VectorConvertToLLVMPattern; +using ExtUIOpLowering = + VectorConvertToLLVMPattern; +using ExtSIOpLowering = + VectorConvertToLLVMPattern; +using ExtFOpLowering = VectorConvertToLLVMPattern; +using TruncIOpLowering = + VectorConvertToLLVMPattern; +using TruncFOpLowering = + VectorConvertToLLVMPattern; +using UIToFPOpLowering = + VectorConvertToLLVMPattern; +using SIToFPOpLowering = + VectorConvertToLLVMPattern; +using FPToUIOpLowering = + VectorConvertToLLVMPattern; +using FPToSIOpLowering = + VectorConvertToLLVMPattern; +using BitcastOpLowering = + VectorConvertToLLVMPattern; + +//===----------------------------------------------------------------------===// +// Op Lowering Patterns +//===----------------------------------------------------------------------===// + +/// Directly lower to LLVM op. +struct ConstantOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// The lowering of index_cast becomes an integer conversion since index +/// becomes an integer. If the bit width of the source and target integer +/// types is the same, just erase the cast. If the target type is wider, +/// sign-extend the value, otherwise truncate it. 
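Roughly, assuming the type converter maps `index` to `i64` (the SSA names below are purely illustrative, not taken from the patch):

```mlir
// Widening cast: the narrower source is sign-extended.
%0 = arith.index_cast %a : i32 to index      // ->  %0 = llvm.sext %a : i32 to i64
// Narrowing cast: the wider source is truncated.
%1 = arith.index_cast %b : index to i32      // ->  %1 = llvm.trunc %b : i64 to i32
// Equal bit widths: the cast is erased and the operand forwarded directly.
```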
+struct IndexCastOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct CmpIOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct CmpFOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ConstantOpLowering +//===----------------------------------------------------------------------===// + +LogicalResult +ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), + adaptor.getOperands(), + *getTypeConverter(), rewriter); +} + +//===----------------------------------------------------------------------===// +// IndexCastOpLowering +//===----------------------------------------------------------------------===// + +LogicalResult IndexCastOpLowering::matchAndRewrite( + arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto targetType = typeConverter->convertType(op.getResult().getType()); + auto targetElementType = + typeConverter->convertType(getElementTypeOrSelf(op.getResult())) + .cast(); + auto sourceElementType = + getElementTypeOrSelf(adaptor.in()).cast(); + unsigned targetBits = targetElementType.getWidth(); + unsigned sourceBits = sourceElementType.getWidth(); + + if (targetBits == sourceBits) + rewriter.replaceOp(op, adaptor.in()); + else if (targetBits < sourceBits) + rewriter.replaceOpWithNewOp(op, targetType, adaptor.in()); + else + rewriter.replaceOpWithNewOp(op, targetType, adaptor.in()); + return success(); +} + +//===----------------------------------------------------------------------===// +// CmpIOpLowering +//===----------------------------------------------------------------------===// + +// Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums +// share numerical values so just cast. +template +static LLVMPredType convertCmpPredicate(PredType pred) { + return static_cast(pred); +} + +LogicalResult +CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto operandType = adaptor.lhs().getType(); + auto resultType = op.getResult().getType(); + + // Handle the scalar and 1D vector cases. 
+ if (!operandType.isa()) { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(resultType), + convertCmpPredicate(op.getPredicate()), + adaptor.lhs(), adaptor.rhs()); + return success(); + } + + auto vectorType = resultType.dyn_cast(); + if (!vectorType) + return rewriter.notifyMatchFailure(op, "expected vector result type"); + + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) { + OpAdaptor adaptor(operands); + return rewriter.create( + op.getLoc(), llvm1DVectorTy, + convertCmpPredicate(op.getPredicate()), + adaptor.lhs(), adaptor.rhs()); + }, + rewriter); + + return success(); +} + +//===----------------------------------------------------------------------===// +// CmpFOpLowering +//===----------------------------------------------------------------------===// + +LogicalResult +CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto operandType = adaptor.lhs().getType(); + auto resultType = op.getResult().getType(); + + // Handle the scalar and 1D vector cases. + if (!operandType.isa()) { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(resultType), + convertCmpPredicate(op.getPredicate()), + adaptor.lhs(), adaptor.rhs()); + return success(); + } + + auto vectorType = resultType.dyn_cast(); + if (!vectorType) + return rewriter.notifyMatchFailure(op, "expected vector result type"); + + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) { + OpAdaptor adaptor(operands); + return rewriter.create( + op.getLoc(), llvm1DVectorTy, + convertCmpPredicate(op.getPredicate()), + adaptor.lhs(), adaptor.rhs()); + }, + rewriter); +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { +struct ConvertArithmeticToLLVMPass + : public ConvertArithmeticToLLVMBase { + ConvertArithmeticToLLVMPass() = default; + + void runOnFunction() override { + LLVMConversionTarget target(getContext()); + RewritePatternSet patterns(&getContext()); + + LowerToLLVMOptions options(&getContext()); + if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) + options.overrideIndexBitwidth(indexBitwidth); + + LLVMTypeConverter converter(&getContext(), options); + mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, + patterns); + + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +void mlir::arith::populateArithmeticToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // clang-format off + patterns.add< + ConstantOpLowering, + AddIOpLowering, + SubIOpLowering, + MulIOpLowering, + DivUIOpLowering, + DivSIOpLowering, + RemUIOpLowering, + RemSIOpLowering, + AndIOpLowering, + OrIOpLowering, + XOrIOpLowering, + ShLIOpLowering, + ShRUIOpLowering, + ShRSIOpLowering, + NegFOpLowering, + AddFOpLowering, + SubFOpLowering, + MulFOpLowering, + DivFOpLowering, + RemFOpLowering, + ExtUIOpLowering, + ExtSIOpLowering, + ExtFOpLowering, + 
TruncIOpLowering, + TruncFOpLowering, + UIToFPOpLowering, + SIToFPOpLowering, + FPToUIOpLowering, + FPToSIOpLowering, + IndexCastOpLowering, + BitcastOpLowering, + CmpIOpLowering, + CmpFOpLowering + >(converter); + // clang-format on +} + +std::unique_ptr mlir::arith::createConvertArithmeticToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRArithmeticToLLVM + ArithmeticToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithmeticToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRLLVMCommonConversion + MLIRLLVMIR + ) diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -0,0 +1,826 @@ +//===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" +#include "../PassDetail.h" +#include "../SPIRVCommon/Pattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "arith-to-spirv-pattern" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Operation Conversion +//===----------------------------------------------------------------------===// + +namespace { + +/// Converts composite arith.constant operation to spv.Constant. +struct ConstantCompositeOpPattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts scalar arith.constant operation to spv.Constant. +struct ConstantScalarOpPattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.remsi to SPIR-V ops. +/// +/// This cannot be merged into the template unary/binary pattern due to Vulkan +/// restrictions over spv.SRem and spv.SMod. +struct RemSIOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts bitwise operations to SPIR-V operations. This is a special pattern +/// other than the BinaryOpPatternPattern because if the operands are boolean +/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For +/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 
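A minimal sketch of that distinction (operand names are illustrative, and the printed SPIR-V form may differ slightly):

```mlir
// Boolean operands map to the logical form.
%0 = arith.andi %p, %q : i1     // ->  %0 = spv.LogicalAnd %p, %q : i1
// Non-boolean operands map to the bitwise form.
%1 = arith.andi %x, %y : i32    // ->  %1 = spv.BitwiseAnd %x, %y : i32
```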
+template +struct BitwiseOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.xori to SPIR-V operations. +struct XOrIOpLogicalPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.xori to SPIR-V operations if the type of source is i1 or +/// vector of i1. +struct XOrIOpBooleanPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of +/// i1. +struct UIToFPI1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.extui to spv.Select if the type of source is i1 or vector of +/// i1. +struct ExtUII1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.trunci to spv.Select if the type of result is i1 or vector of +/// i1. +struct TruncII1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts type-casting standard operations to SPIR-V operations. +template +struct TypeCastingOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts integer compare operation on i1 type operands to SPIR-V ops. +class CmpIOpBooleanPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts integer compare operation to SPIR-V ops. +class CmpIOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts floating-point comparison operations to SPIR-V ops. +class CmpFOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts floating point NaN check to SPIR-V ops. This pattern requires +/// Kernel capability. 
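Roughly, when the Kernel capability is available (value names are illustrative):

```mlir
%0 = arith.cmpf ord, %a, %b : f32    // ->  %0 = spv.Ordered %a, %b : f32
%1 = arith.cmpf uno, %a, %b : f32    // ->  %1 = spv.Unordered %a, %b : f32
```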
+class CmpFOpNanKernelPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts floating point NaN check to SPIR-V ops. This pattern does not +/// require additional capability. +class CmpFOpNanNonePattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Conversion Helpers +//===----------------------------------------------------------------------===// + +/// Converts the given `srcAttr` into a boolean attribute if it holds an +/// integral value. Returns null attribute if conversion fails. +static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { + if (auto boolAttr = srcAttr.dyn_cast()) + return boolAttr; + if (auto intAttr = srcAttr.dyn_cast()) + return builder.getBoolAttr(intAttr.getValue().getBoolValue()); + return BoolAttr(); +} + +/// Converts the given `srcAttr` to a new attribute of the given `dstType`. +/// Returns null attribute if conversion fails. +static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, + Builder builder) { + // If the source number uses less active bits than the target bitwidth, then + // it should be safe to convert. + if (srcAttr.getValue().isIntN(dstType.getWidth())) + return builder.getIntegerAttr(dstType, srcAttr.getInt()); + + // XXX: Try again by interpreting the source number as a signed value. + // Although integers in the standard dialect are signless, they can represent + // a signed number. It's the operation decides how to interpret. This is + // dangerous, but it seems there is no good way of handling this if we still + // want to change the bitwidth. Emit a message at least. + if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { + auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); + LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" + << dstAttr << "' for type '" << dstType << "'\n"); + return dstAttr; + } + + LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr + << "' illegal: cannot fit into target type '" + << dstType << "'\n"); + return IntegerAttr(); +} + +/// Converts the given `srcAttr` to a new attribute of the given `dstType`. +/// Returns null attribute if `dstType` is not 32-bit or conversion fails. +static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, + Builder builder) { + // Only support converting to float for now. + if (!dstType.isF32()) + return FloatAttr(); + + // Try to convert the source floating-point number to single precision. + APFloat dstVal = srcAttr.getValue(); + bool losesInfo = false; + APFloat::opStatus status = + dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); + if (status != APFloat::opOK || losesInfo) { + LLVM_DEBUG(llvm::dbgs() + << srcAttr << " illegal: cannot fit into converted type '" + << dstType << "'\n"); + return FloatAttr(); + } + + return builder.getF32FloatAttr(dstVal.convertToFloat()); +} + +/// Returns true if the given `type` is a boolean scalar or vector type. 
+static bool isBoolScalarOrVector(Type type) { + if (type.isInteger(1)) + return true; + if (auto vecType = type.dyn_cast()) + return vecType.getElementType().isInteger(1); + return false; +} + +//===----------------------------------------------------------------------===// +// ConstantOp with composite type +//===----------------------------------------------------------------------===// + +LogicalResult ConstantCompositeOpPattern::matchAndRewrite( + arith::ConstantOp constOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto srcType = constOp.getType().dyn_cast(); + if (!srcType) + return failure(); + + // std.constant should only have vector or tenor types. + assert((srcType.isa())); + + auto dstType = getTypeConverter()->convertType(srcType); + if (!dstType) + return failure(); + + auto dstElementsAttr = constOp.value().dyn_cast(); + ShapedType dstAttrType = dstElementsAttr.getType(); + if (!dstElementsAttr) + return failure(); + + // If the composite type has more than one dimensions, perform linearization. + if (srcType.getRank() > 1) { + if (srcType.isa()) { + dstAttrType = RankedTensorType::get(srcType.getNumElements(), + srcType.getElementType()); + dstElementsAttr = dstElementsAttr.reshape(dstAttrType); + } else { + // TODO: add support for large vectors. + return failure(); + } + } + + Type srcElemType = srcType.getElementType(); + Type dstElemType; + // Tensor types are converted to SPIR-V array types; vector types are + // converted to SPIR-V vector/array types. + if (auto arrayType = dstType.dyn_cast()) + dstElemType = arrayType.getElementType(); + else + dstElemType = dstType.cast().getElementType(); + + // If the source and destination element types are different, perform + // attribute conversion. + if (srcElemType != dstElemType) { + SmallVector elements; + if (srcElemType.isa()) { + for (FloatAttr srcAttr : dstElementsAttr.getValues()) { + FloatAttr dstAttr = + convertFloatAttr(srcAttr, dstElemType.cast(), rewriter); + if (!dstAttr) + return failure(); + elements.push_back(dstAttr); + } + } else if (srcElemType.isInteger(1)) { + return failure(); + } else { + for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { + IntegerAttr dstAttr = convertIntegerAttr( + srcAttr, dstElemType.cast(), rewriter); + if (!dstAttr) + return failure(); + elements.push_back(dstAttr); + } + } + + // Unfortunately, we cannot use dialect-specific types for element + // attributes; element attributes only works with builtin types. So we need + // to prepare another converted builtin types for the destination elements + // attribute. + if (dstAttrType.isa()) + dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); + else + dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); + + dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); + } + + rewriter.replaceOpWithNewOp(constOp, dstType, + dstElementsAttr); + return success(); +} + +//===----------------------------------------------------------------------===// +// ConstantOp with scalar type +//===----------------------------------------------------------------------===// + +LogicalResult ConstantScalarOpPattern::matchAndRewrite( + arith::ConstantOp constOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Type srcType = constOp.getType(); + if (!srcType.isIntOrIndexOrFloat()) + return failure(); + + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) + return failure(); + + // Floating-point types. 
+ if (srcType.isa()) { + auto srcAttr = constOp.value().cast(); + auto dstAttr = srcAttr; + + // Floating-point types not supported in the target environment are all + // converted to float type. + if (srcType != dstType) { + dstAttr = convertFloatAttr(srcAttr, dstType.cast(), rewriter); + if (!dstAttr) + return failure(); + } + + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); + } + + // Bool type. + if (srcType.isInteger(1)) { + // std.constant can use 0/1 instead of true/false for i1 values. We need to + // handle that here. + auto dstAttr = convertBoolAttr(constOp.value(), rewriter); + if (!dstAttr) + return failure(); + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); + } + + // IndexType or IntegerType. Index values are converted to 32-bit integer + // values when converting to SPIR-V. + auto srcAttr = constOp.value().cast(); + auto dstAttr = + convertIntegerAttr(srcAttr, dstType.cast(), rewriter); + if (!dstAttr) + return failure(); + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); +} + +//===----------------------------------------------------------------------===// +// RemSIOpPattern +//===----------------------------------------------------------------------===// + +/// Returns signed remainder for `lhs` and `rhs` and lets the result follow +/// the sign of `signOperand`. +/// +/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment +/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative +/// the result is undefined." So we cannot directly use spv.SRem/spv.SMod +/// if either operand can be negative. Emulate it via spv.UMod. +static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, + Value signOperand, OpBuilder &builder) { + assert(lhs.getType() == rhs.getType()); + assert(lhs == signOperand || rhs == signOperand); + + Type type = lhs.getType(); + + // Calculate the remainder with spv.UMod. + Value lhsAbs = builder.create(loc, type, lhs); + Value rhsAbs = builder.create(loc, type, rhs); + Value abs = builder.create(loc, lhsAbs, rhsAbs); + + // Fix the sign. 
+ Value isPositive; + if (lhs == signOperand) + isPositive = builder.create(loc, lhs, lhsAbs); + else + isPositive = builder.create(loc, rhs, rhsAbs); + Value absNegate = builder.create(loc, type, abs); + return builder.create(loc, type, isPositive, abs, absNegate); +} + +LogicalResult +RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0], + adaptor.getOperands()[1], + adaptor.getOperands()[0], rewriter); + rewriter.replaceOp(op, result); + + return success(); +} + +//===----------------------------------------------------------------------===// +// BitwiseOpPattern +//===----------------------------------------------------------------------===// + +template +LogicalResult +BitwiseOpPattern::matchAndRewrite( + Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(adaptor.getOperands().size() == 2); + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!dstType) + return failure(); + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + } else { + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// XOrIOpLogicalPattern +//===----------------------------------------------------------------------===// + +LogicalResult XOrIOpLogicalPattern::matchAndRewrite( + arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(adaptor.getOperands().size() == 2); + + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); + + auto dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XOrIOpBooleanPattern +//===----------------------------------------------------------------------===// + +LogicalResult XOrIOpBooleanPattern::matchAndRewrite( + arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(adaptor.getOperands().size() == 2); + + if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); + + auto dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + return success(); +} + +//===----------------------------------------------------------------------===// +// UIToFPI1Pattern +//===----------------------------------------------------------------------===// + +LogicalResult +UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto srcType = adaptor.getOperands().front().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); + + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.template replaceOpWithNewOp( + op, dstType, adaptor.getOperands().front(), one, zero); + return success(); +} + 
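As a rough sketch, the i1-to-float conversion above turns into a select between constants rather than a real conversion instruction (names below are illustrative):

```mlir
%f = arith.uitofp %b : i1 to f32
// becomes approximately:
%zero = spv.Constant 0.000000e+00 : f32
%one  = spv.Constant 1.000000e+00 : f32
%f    = spv.Select %b, %one, %zero : i1, f32
```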
+//===----------------------------------------------------------------------===// +// ExtUII1Pattern +//===----------------------------------------------------------------------===// + +LogicalResult +ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto srcType = adaptor.getOperands().front().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); + + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.template replaceOpWithNewOp( + op, dstType, adaptor.getOperands().front(), one, zero); + return success(); +} + +//===----------------------------------------------------------------------===// +// TruncII1Pattern +//===----------------------------------------------------------------------===// + +LogicalResult +TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!isBoolScalarOrVector(dstType)) + return failure(); + + Location loc = op.getLoc(); + auto srcType = adaptor.getOperands().front().getType(); + // Check if (x & 1) == 1. + Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); + Value maskedSrc = rewriter.create( + loc, srcType, adaptor.getOperands()[0], mask); + Value isOne = rewriter.create(loc, maskedSrc, mask); + + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); + return success(); +} + +//===----------------------------------------------------------------------===// +// TypeCastingOpPattern +//===----------------------------------------------------------------------===// + +template +LogicalResult TypeCastingOpPattern::matchAndRewrite( + Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(adaptor.getOperands().size() == 1); + auto srcType = adaptor.getOperands().front().getType(); + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) + return failure(); + if (dstType == srcType) { + // Due to type conversion, we are seeing the same source and target type. + // Then we can just erase this operation by forwarding its operand. 
+ rewriter.replaceOp(op, adaptor.getOperands().front()); + } else { + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// CmpIOpBooleanPattern +//===----------------------------------------------------------------------===// + +LogicalResult CmpIOpBooleanPattern::matchAndRewrite( + arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Type operandType = op.lhs().getType(); + if (!isBoolScalarOrVector(operandType)) + return failure(); + + switch (op.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ + adaptor.lhs(), adaptor.rhs()); \ + return success(); + + DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); + DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); + +#undef DISPATCH + default:; + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// CmpIOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult +CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Type operandType = op.lhs().getType(); + if (isBoolScalarOrVector(operandType)) + return failure(); + + switch (op.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + if (spirvOp::template hasTrait() && \ + operandType != this->getTypeConverter()->convertType(operandType)) { \ + return op.emitError( \ + "bitwidth emulation is not implemented yet on unsigned op"); \ + } \ + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ + adaptor.lhs(), adaptor.rhs()); \ + return success(); + + DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); + DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); + DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); + DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); + DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); + DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); + DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); + DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); + DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); + DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); + +#undef DISPATCH + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// CmpFOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult +CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + switch (op.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ + adaptor.lhs(), adaptor.rhs()); \ + return success(); + + // Ordered. + DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); + DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); + DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); + DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); + DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); + DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); + // Unordered. 
+ DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); + DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); + DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); + DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); + DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); + DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); + +#undef DISPATCH + + default: + break; + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// CmpFOpNanKernelPattern +//===----------------------------------------------------------------------===// + +LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( + arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (op.getPredicate() == arith::CmpFPredicate::ORD) { + rewriter.replaceOpWithNewOp(op, adaptor.lhs(), + adaptor.rhs()); + return success(); + } + + if (op.getPredicate() == arith::CmpFPredicate::UNO) { + rewriter.replaceOpWithNewOp(op, adaptor.lhs(), + adaptor.rhs()); + return success(); + } + + return failure(); +} + +//===----------------------------------------------------------------------===// +// CmpFOpNanNonePattern +//===----------------------------------------------------------------------===// + +LogicalResult CmpFOpNanNonePattern::matchAndRewrite( + arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (op.getPredicate() != arith::CmpFPredicate::ORD && + op.getPredicate() != arith::CmpFPredicate::UNO) + return failure(); + + Location loc = op.getLoc(); + + Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); + Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); + + Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); + if (op.getPredicate() == arith::CmpFPredicate::ORD) + replace = rewriter.create(loc, replace); + + rewriter.replaceOp(op, replace); + return success(); +} + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +void mlir::arith::populateArithmeticToSPIRVPatterns( + SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { + // clang-format off + patterns.add< + ConstantCompositeOpPattern, + ConstantScalarOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + RemSIOpPattern, + BitwiseOpPattern, + BitwiseOpPattern, + XOrIOpLogicalPattern, XOrIOpBooleanPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + TypeCastingOpPattern, ExtUII1Pattern, + TypeCastingOpPattern, + TypeCastingOpPattern, + TypeCastingOpPattern, TruncII1Pattern, + TypeCastingOpPattern, + TypeCastingOpPattern, UIToFPI1Pattern, + TypeCastingOpPattern, + TypeCastingOpPattern, + TypeCastingOpPattern, + CmpIOpBooleanPattern, CmpIOpPattern, + CmpFOpNanNonePattern, CmpFOpPattern + >(typeConverter, patterns.getContext()); + // clang-format on + + // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel + // capability is available. 
+ patterns.add(typeConverter, patterns.getContext(), + /*benefit=*/2); +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { +struct ConvertArithmeticToSPIRVPass + : public ConvertArithmeticToSPIRVBase { + void runOnFunction() override { + auto module = getOperation()->getParentOfType(); + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + auto target = SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter::Options options; + options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; + SPIRVTypeConverter typeConverter(targetAttr, options); + + RewritePatternSet patterns(&getContext()); + mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(getOperation(), *target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // end anonymous namespace + +std::unique_ptr mlir::arith::createConvertArithmeticToSPIRVPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRArithmeticToSPIRV + ArithmeticToSPIRV.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithmeticToSPIRV + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRSPIRVConversion + MLIRSPIRV + ) diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -903,9 +904,9 @@ LogicalResult matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto count = - rewriter.create(op->getLoc(), rewriter.getI64Type(), - rewriter.getI64IntegerAttr(op.count())); + auto count = rewriter.create( + op->getLoc(), rewriter.getI64Type(), + rewriter.getI64IntegerAttr(op.count())); auto operand = adaptor.operand(); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, @@ -1008,7 +1009,8 @@ converter, ctx); ConversionTarget target(*ctx); - target.addLegalOp(); + target + .addLegalOp(); target.addLegalDialect(); // All operations from Async dialect must be lowered to the runtime API and diff --git a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRAsync MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,4 +1,6 @@ add_subdirectory(AffineToStandard) +add_subdirectory(ArithmeticToLLVM) +add_subdirectory(ArithmeticToSPIRV) 
add_subdirectory(ArmNeon2dToIntr) add_subdirectory(AsyncToLLVM) add_subdirectory(ComplexToLLVM) diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -11,8 +11,10 @@ #include "../PassDetail.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" using namespace mlir; using namespace mlir::LLVM; diff --git a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt @@ -8,6 +8,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRComplex MLIRIR MLIRMath diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -12,6 +12,7 @@ #include #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -33,21 +34,21 @@ Value real = rewriter.create(loc, type, adaptor.complex()); Value imag = rewriter.create(loc, type, adaptor.complex()); - Value realSqr = rewriter.create(loc, real, real); - Value imagSqr = rewriter.create(loc, imag, imag); - Value sqNorm = rewriter.create(loc, realSqr, imagSqr); + Value realSqr = rewriter.create(loc, real, real); + Value imagSqr = rewriter.create(loc, imag, imag); + Value sqNorm = rewriter.create(loc, realSqr, imagSqr); rewriter.replaceOpWithNewOp(op, sqNorm); return success(); } }; -template +template struct ComparisonOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using ResultCombiner = std::conditional_t::value, - AndOp, OrOp>; + arith::AndIOp, arith::OrIOp>; LogicalResult matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, @@ -60,8 +61,10 @@ Value imagLhs = rewriter.create(loc, type, adaptor.lhs()); Value realRhs = rewriter.create(loc, type, adaptor.rhs()); Value imagRhs = rewriter.create(loc, type, adaptor.rhs()); - Value realComparison = rewriter.create(loc, p, realLhs, realRhs); - Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); + Value realComparison = + rewriter.create(loc, p, realLhs, realRhs); + Value imagComparison = + rewriter.create(loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, realComparison, imagComparison); @@ -138,139 +141,150 @@ // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom // // See https://dl.acm.org/citation.cfm?id=368661 for more details. 
- Value rhsRealImagRatio = rewriter.create(loc, rhsReal, rhsImag); - Value rhsRealImagDenom = rewriter.create( - loc, rhsImag, rewriter.create(loc, rhsRealImagRatio, rhsReal)); - Value realNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), lhsImag); + Value rhsRealImagRatio = + rewriter.create(loc, rhsReal, rhsImag); + Value rhsRealImagDenom = rewriter.create( + loc, rhsImag, + rewriter.create(loc, rhsRealImagRatio, rhsReal)); + Value realNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), + lhsImag); Value resultReal1 = - rewriter.create(loc, realNumerator1, rhsRealImagDenom); - Value imagNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), lhsReal); + rewriter.create(loc, realNumerator1, rhsRealImagDenom); + Value imagNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), + lhsReal); Value resultImag1 = - rewriter.create(loc, imagNumerator1, rhsRealImagDenom); - - Value rhsImagRealRatio = rewriter.create(loc, rhsImag, rhsReal); - Value rhsImagRealDenom = rewriter.create( - loc, rhsReal, rewriter.create(loc, rhsImagRealRatio, rhsImag)); - Value realNumerator2 = rewriter.create( - loc, lhsReal, rewriter.create(loc, lhsImag, rhsImagRealRatio)); + rewriter.create(loc, imagNumerator1, rhsRealImagDenom); + + Value rhsImagRealRatio = + rewriter.create(loc, rhsImag, rhsReal); + Value rhsImagRealDenom = rewriter.create( + loc, rhsReal, + rewriter.create(loc, rhsImagRealRatio, rhsImag)); + Value realNumerator2 = rewriter.create( + loc, lhsReal, + rewriter.create(loc, lhsImag, rhsImagRealRatio)); Value resultReal2 = - rewriter.create(loc, realNumerator2, rhsImagRealDenom); - Value imagNumerator2 = rewriter.create( - loc, lhsImag, rewriter.create(loc, lhsReal, rhsImagRealRatio)); + rewriter.create(loc, realNumerator2, rhsImagRealDenom); + Value imagNumerator2 = rewriter.create( + loc, lhsImag, + rewriter.create(loc, lhsReal, rhsImagRealRatio)); Value resultImag2 = - rewriter.create(loc, imagNumerator2, rhsImagRealDenom); + rewriter.create(loc, imagNumerator2, rhsImagRealDenom); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. 
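The special-case classification below now spells out the predicate on `arith.cmpf` and combines the `i1` results with `arith.andi`/`arith.ori`. Schematically, with made-up SSA names and assuming `math.abs` as the absolute-value op of this vintage:

```mlir
%zero         = arith.constant 0.0 : f32
%rhsReAbs     = math.abs %rhsReal : f32
%rhsReIsZero  = arith.cmpf oeq, %rhsReAbs, %zero : f32
%lhsReNotNaN  = arith.cmpf ord, %lhsReal, %zero : f32
%lhsImNotNaN  = arith.cmpf ord, %lhsImag, %zero : f32
%lhsHasNonNaN = arith.ori %lhsReNotNaN, %lhsImNotNaN : i1
```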
- Value zero = rewriter.create(loc, elementType, - rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create(loc, rhsReal); - Value rhsRealIsZero = - rewriter.create(loc, CmpFPredicate::OEQ, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create(loc, rhsImag); - Value rhsImagIsZero = - rewriter.create(loc, CmpFPredicate::OEQ, rhsImagAbs, zero); - Value lhsRealIsNotNaN = - rewriter.create(loc, CmpFPredicate::ORD, lhsReal, zero); - Value lhsImagIsNotNaN = - rewriter.create(loc, CmpFPredicate::ORD, lhsImag, zero); + Value zero = rewriter.create( + loc, elementType, rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = rewriter.create(loc, rhsReal); + Value rhsRealIsZero = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); + Value rhsImagAbs = rewriter.create(loc, rhsImag); + Value rhsImagIsZero = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); + Value lhsRealIsNotNaN = rewriter.create( + loc, arith::CmpFPredicate::ORD, lhsReal, zero); + Value lhsImagIsNotNaN = rewriter.create( + loc, arith::CmpFPredicate::ORD, lhsImag, zero); Value lhsContainsNotNaNValue = - rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create( + rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = rewriter.create( loc, lhsContainsNotNaNValue, - rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create( + rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = rewriter.create( loc, elementType, rewriter.getFloatAttr( elementType, APFloat::getInf(elementType.getFloatSemantics()))); - Value infWithSignOfRhsReal = rewriter.create(loc, inf, rhsReal); + Value infWithSignOfRhsReal = + rewriter.create(loc, inf, rhsReal); Value infinityResultReal = - rewriter.create(loc, infWithSignOfRhsReal, lhsReal); + rewriter.create(loc, infWithSignOfRhsReal, lhsReal); Value infinityResultImag = - rewriter.create(loc, infWithSignOfRhsReal, lhsImag); + rewriter.create(loc, infWithSignOfRhsReal, lhsImag); // Case 2. Infinite numerator, finite denominator. 
- Value rhsRealFinite = - rewriter.create(loc, CmpFPredicate::ONE, rhsRealAbs, inf); - Value rhsImagFinite = - rewriter.create(loc, CmpFPredicate::ONE, rhsImagAbs, inf); - Value rhsFinite = rewriter.create(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create(loc, lhsReal); - Value lhsRealInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create(loc, lhsImag); - Value lhsImagInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, lhsImagAbs, inf); + Value rhsRealFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); + Value rhsImagFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); + Value rhsFinite = + rewriter.create(loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = rewriter.create(loc, lhsReal); + Value lhsRealInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagAbs = rewriter.create(loc, lhsImag); + Value lhsImagInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create( + rewriter.create(loc, lhsInfinite, rhsFinite); + Value one = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create( + Value lhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsRealInfinite, one, zero), lhsReal); - Value lhsImagIsInfWithSign = rewriter.create( + Value lhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsImagInfinite, one, zero), lhsImag); Value lhsRealIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); + rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); Value lhsImagIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsImagIsInfWithSign, rhsImag); - Value resultReal3 = rewriter.create( + rewriter.create(loc, lhsImagIsInfWithSign, rhsImag); + Value resultReal3 = rewriter.create( loc, inf, - rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, - lhsImagIsInfWithSignTimesRhsImag)); + rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, + lhsImagIsInfWithSignTimesRhsImag)); Value lhsRealIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsRealIsInfWithSign, rhsImag); + rewriter.create(loc, lhsRealIsInfWithSign, rhsImag); Value lhsImagIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsImagIsInfWithSign, rhsReal); - Value resultImag3 = rewriter.create( + rewriter.create(loc, lhsImagIsInfWithSign, rhsReal); + Value resultImag3 = rewriter.create( loc, inf, - rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, - lhsRealIsInfWithSignTimesRhsImag)); + rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, + lhsRealIsInfWithSignTimesRhsImag)); // Case 3: Finite numerator, infinite denominator. 
- Value lhsRealFinite = - rewriter.create(loc, CmpFPredicate::ONE, lhsRealAbs, inf); - Value lhsImagFinite = - rewriter.create(loc, CmpFPredicate::ONE, lhsImagAbs, inf); - Value lhsFinite = rewriter.create(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, rhsImagAbs, inf); + Value lhsRealFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); + Value lhsImagFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); + Value lhsFinite = + rewriter.create(loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = - rewriter.create(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create( + rewriter.create(loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsRealInfinite, one, zero), rhsReal); - Value rhsImagIsInfWithSign = rewriter.create( + Value rhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsImagInfinite, one, zero), rhsImag); Value rhsRealIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); + rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsImag, rhsImagIsInfWithSign); - Value resultReal4 = rewriter.create( + rewriter.create(loc, lhsImag, rhsImagIsInfWithSign); + Value resultReal4 = rewriter.create( loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, - rhsImagIsInfWithSignTimesLhsImag)); + rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, + rhsImagIsInfWithSignTimesLhsImag)); Value rhsRealIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsImag, rhsRealIsInfWithSign); + rewriter.create(loc, lhsImag, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsReal, rhsImagIsInfWithSign); - Value resultImag4 = rewriter.create( + rewriter.create(loc, lhsReal, rhsImagIsInfWithSign); + Value resultImag4 = rewriter.create( loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, - rhsImagIsInfWithSignTimesLhsReal)); + rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, + rhsImagIsInfWithSignTimesLhsReal)); - Value realAbsSmallerThanImagAbs = rewriter.create( - loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); + Value realAbsSmallerThanImagAbs = rewriter.create( + loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); Value resultReal = rewriter.create(loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); Value resultImag = rewriter.create(loc, realAbsSmallerThanImagAbs, @@ -288,12 +302,12 @@ Value resultImagSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = - rewriter.create(loc, CmpFPredicate::UNO, resultReal, zero); - Value resultImagIsNaN = - rewriter.create(loc, CmpFPredicate::UNO, resultImag, zero); + Value resultRealIsNaN = rewriter.create( + loc, arith::CmpFPredicate::UNO, resultReal, zero); + Value resultImagIsNaN = rewriter.create( + loc, arith::CmpFPredicate::UNO, resultImag, zero); Value resultIsNaN = - rewriter.create(loc, 
resultRealIsNaN, resultImagIsNaN); + rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); Value resultRealWithSpecialCases = rewriter.create( loc, resultIsNaN, resultRealSpecialCase1, resultReal); Value resultImagWithSpecialCases = rewriter.create( @@ -321,9 +335,9 @@ rewriter.create(loc, elementType, adaptor.complex()); Value expReal = rewriter.create(loc, real); Value cosImag = rewriter.create(loc, imag); - Value resultReal = rewriter.create(loc, expReal, cosImag); + Value resultReal = rewriter.create(loc, expReal, cosImag); Value sinImag = rewriter.create(loc, imag); - Value resultImag = rewriter.create(loc, expReal, sinImag); + Value resultImag = rewriter.create(loc, expReal, sinImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -364,9 +378,9 @@ Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); - Value one = - b.create(elementType, b.getFloatAttr(elementType, 1)); - Value realPlusOne = b.create(real, one); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1)); + Value realPlusOne = b.create(real, one); Value newComplex = b.create(type, realPlusOne, imag); rewriter.replaceOpWithNewOp(op, type, newComplex); return success(); @@ -384,126 +398,162 @@ auto elementType = type.getElementType().cast(); Value lhsReal = b.create(elementType, adaptor.lhs()); - Value lhsRealAbs = b.create(lhsReal); + Value lhsRealAbs = b.create(lhsReal); Value lhsImag = b.create(elementType, adaptor.lhs()); - Value lhsImagAbs = b.create(lhsImag); + Value lhsImagAbs = b.create(lhsImag); Value rhsReal = b.create(elementType, adaptor.rhs()); - Value rhsRealAbs = b.create(rhsReal); + Value rhsRealAbs = b.create(rhsReal); Value rhsImag = b.create(elementType, adaptor.rhs()); - Value rhsImagAbs = b.create(rhsImag); + Value rhsImagAbs = b.create(rhsImag); - Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); - Value lhsRealTimesRhsRealAbs = b.create(lhsRealTimesRhsReal); - Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); - Value lhsImagTimesRhsImagAbs = b.create(lhsImagTimesRhsImag); - Value real = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); + Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); + Value lhsRealTimesRhsRealAbs = b.create(lhsRealTimesRhsReal); + Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); + Value lhsImagTimesRhsImagAbs = b.create(lhsImagTimesRhsImag); + Value real = + b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); - Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); - Value lhsImagTimesRhsRealAbs = b.create(lhsImagTimesRhsReal); - Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); - Value lhsRealTimesRhsImagAbs = b.create(lhsRealTimesRhsImag); - Value imag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); + Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); + Value lhsImagTimesRhsRealAbs = b.create(lhsImagTimesRhsReal); + Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); + Value lhsRealTimesRhsImagAbs = b.create(lhsRealTimesRhsImag); + Value imag = + b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); // Handle cases where the "naive" calculation results in NaN values. 
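The NaN check on the naive product relies on the unordered `uno` predicate, which is true exactly when a comparison operand is NaN. A minimal sketch with illustrative names:

```mlir
%ac      = arith.mulf %lhsRe, %rhsRe : f32
%bd      = arith.mulf %lhsIm, %rhsIm : f32
%re      = arith.subf %ac, %bd : f32
%reIsNaN = arith.cmpf uno, %re, %re : f32
```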
- Value realIsNan = b.create(CmpFPredicate::UNO, real, real); - Value imagIsNan = b.create(CmpFPredicate::UNO, imag, imag); - Value isNan = b.create(realIsNan, imagIsNan); + Value realIsNan = + b.create(arith::CmpFPredicate::UNO, real, real); + Value imagIsNan = + b.create(arith::CmpFPredicate::UNO, imag, imag); + Value isNan = b.create(realIsNan, imagIsNan); - Value inf = b.create( + Value inf = b.create( elementType, b.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); // Case 1. `lhsReal` or `lhsImag` are infinite. - Value lhsRealIsInf = b.create(CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagIsInf = b.create(CmpFPredicate::OEQ, lhsImagAbs, inf); - Value lhsIsInf = b.create(lhsRealIsInf, lhsImagIsInf); - Value rhsRealIsNan = b.create(CmpFPredicate::UNO, rhsReal, rhsReal); - Value rhsImagIsNan = b.create(CmpFPredicate::UNO, rhsImag, rhsImag); - Value zero = b.create(elementType, b.getZeroAttr(elementType)); - Value one = - b.create(elementType, b.getFloatAttr(elementType, 1)); + Value lhsRealIsInf = + b.create(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagIsInf = + b.create(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); + Value lhsIsInf = b.create(lhsRealIsInf, lhsImagIsInf); + Value rhsRealIsNan = + b.create(arith::CmpFPredicate::UNO, rhsReal, rhsReal); + Value rhsImagIsNan = + b.create(arith::CmpFPredicate::UNO, rhsImag, rhsImag); + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1)); Value lhsRealIsInfFloat = b.create(lhsRealIsInf, one, zero); lhsReal = b.create( - lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), lhsReal); + lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), + lhsReal); Value lhsImagIsInfFloat = b.create(lhsImagIsInf, one, zero); lhsImag = b.create( - lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), lhsImag); - Value lhsIsInfAndRhsRealIsNan = b.create(lhsIsInf, rhsRealIsNan); - rhsReal = b.create(lhsIsInfAndRhsRealIsNan, - b.create(zero, rhsReal), rhsReal); - Value lhsIsInfAndRhsImagIsNan = b.create(lhsIsInf, rhsImagIsNan); - rhsImag = b.create(lhsIsInfAndRhsImagIsNan, - b.create(zero, rhsImag), rhsImag); + lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), + lhsImag); + Value lhsIsInfAndRhsRealIsNan = + b.create(lhsIsInf, rhsRealIsNan); + rhsReal = + b.create(lhsIsInfAndRhsRealIsNan, + b.create(zero, rhsReal), rhsReal); + Value lhsIsInfAndRhsImagIsNan = + b.create(lhsIsInf, rhsImagIsNan); + rhsImag = + b.create(lhsIsInfAndRhsImagIsNan, + b.create(zero, rhsImag), rhsImag); // Case 2. `rhsReal` or `rhsImag` are infinite. 
- Value rhsRealIsInf = b.create(CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagIsInf = b.create(CmpFPredicate::OEQ, rhsImagAbs, inf); - Value rhsIsInf = b.create(rhsRealIsInf, rhsImagIsInf); - Value lhsRealIsNan = b.create(CmpFPredicate::UNO, lhsReal, lhsReal); - Value lhsImagIsNan = b.create(CmpFPredicate::UNO, lhsImag, lhsImag); + Value rhsRealIsInf = + b.create(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagIsInf = + b.create(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); + Value rhsIsInf = b.create(rhsRealIsInf, rhsImagIsInf); + Value lhsRealIsNan = + b.create(arith::CmpFPredicate::UNO, lhsReal, lhsReal); + Value lhsImagIsNan = + b.create(arith::CmpFPredicate::UNO, lhsImag, lhsImag); Value rhsRealIsInfFloat = b.create(rhsRealIsInf, one, zero); rhsReal = b.create( - rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), rhsReal); + rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), + rhsReal); Value rhsImagIsInfFloat = b.create(rhsImagIsInf, one, zero); rhsImag = b.create( - rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), rhsImag); - Value rhsIsInfAndLhsRealIsNan = b.create(rhsIsInf, lhsRealIsNan); - lhsReal = b.create(rhsIsInfAndLhsRealIsNan, - b.create(zero, lhsReal), lhsReal); - Value rhsIsInfAndLhsImagIsNan = b.create(rhsIsInf, lhsImagIsNan); - lhsImag = b.create(rhsIsInfAndLhsImagIsNan, - b.create(zero, lhsImag), lhsImag); - Value recalc = b.create(lhsIsInf, rhsIsInf); + rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), + rhsImag); + Value rhsIsInfAndLhsRealIsNan = + b.create(rhsIsInf, lhsRealIsNan); + lhsReal = + b.create(rhsIsInfAndLhsRealIsNan, + b.create(zero, lhsReal), lhsReal); + Value rhsIsInfAndLhsImagIsNan = + b.create(rhsIsInf, lhsImagIsNan); + lhsImag = + b.create(rhsIsInfAndLhsImagIsNan, + b.create(zero, lhsImag), lhsImag); + Value recalc = b.create(lhsIsInf, rhsIsInf); // Case 3. One of the pairwise products of left hand side with right hand // side is infinite. 
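Since the arith dialect has no dedicated boolean negation, the rewritten code below builds the negation of `recalc` as an exclusive-or with a true constant, roughly:

```mlir
%true      = arith.constant true
%notRecalc = arith.xori %recalc, %true : i1
```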
- Value lhsRealTimesRhsRealIsInf = - b.create(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); - Value lhsImagTimesRhsImagIsInf = - b.create(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); - Value isSpecialCase = - b.create(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf); - Value lhsRealTimesRhsImagIsInf = - b.create(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); - isSpecialCase = b.create(isSpecialCase, lhsRealTimesRhsImagIsInf); - Value lhsImagTimesRhsRealIsInf = - b.create(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); - isSpecialCase = b.create(isSpecialCase, lhsImagTimesRhsRealIsInf); + Value lhsRealTimesRhsRealIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); + Value lhsImagTimesRhsImagIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); + Value isSpecialCase = b.create(lhsRealTimesRhsRealIsInf, + lhsImagTimesRhsImagIsInf); + Value lhsRealTimesRhsImagIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); + isSpecialCase = + b.create(isSpecialCase, lhsRealTimesRhsImagIsInf); + Value lhsImagTimesRhsRealIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); + isSpecialCase = + b.create(isSpecialCase, lhsImagTimesRhsRealIsInf); Type i1Type = b.getI1Type(); - Value notRecalc = b.create( - recalc, b.create(i1Type, b.getIntegerAttr(i1Type, 1))); - isSpecialCase = b.create(isSpecialCase, notRecalc); + Value notRecalc = b.create( + recalc, + b.create(i1Type, b.getIntegerAttr(i1Type, 1))); + isSpecialCase = b.create(isSpecialCase, notRecalc); Value isSpecialCaseAndLhsRealIsNan = - b.create(isSpecialCase, lhsRealIsNan); - lhsReal = b.create(isSpecialCaseAndLhsRealIsNan, - b.create(zero, lhsReal), lhsReal); + b.create(isSpecialCase, lhsRealIsNan); + lhsReal = + b.create(isSpecialCaseAndLhsRealIsNan, + b.create(zero, lhsReal), lhsReal); Value isSpecialCaseAndLhsImagIsNan = - b.create(isSpecialCase, lhsImagIsNan); - lhsImag = b.create(isSpecialCaseAndLhsImagIsNan, - b.create(zero, lhsImag), lhsImag); + b.create(isSpecialCase, lhsImagIsNan); + lhsImag = + b.create(isSpecialCaseAndLhsImagIsNan, + b.create(zero, lhsImag), lhsImag); Value isSpecialCaseAndRhsRealIsNan = - b.create(isSpecialCase, rhsRealIsNan); - rhsReal = b.create(isSpecialCaseAndRhsRealIsNan, - b.create(zero, rhsReal), rhsReal); + b.create(isSpecialCase, rhsRealIsNan); + rhsReal = + b.create(isSpecialCaseAndRhsRealIsNan, + b.create(zero, rhsReal), rhsReal); Value isSpecialCaseAndRhsImagIsNan = - b.create(isSpecialCase, rhsImagIsNan); - rhsImag = b.create(isSpecialCaseAndRhsImagIsNan, - b.create(zero, rhsImag), rhsImag); - recalc = b.create(recalc, isSpecialCase); - recalc = b.create(isNan, recalc); + b.create(isSpecialCase, rhsImagIsNan); + rhsImag = + b.create(isSpecialCaseAndRhsImagIsNan, + b.create(zero, rhsImag), rhsImag); + recalc = b.create(recalc, isSpecialCase); + recalc = b.create(isNan, recalc); // Recalculate real part. - lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); - lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); - Value newReal = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); - real = b.create(recalc, b.create(inf, newReal), real); + lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); + lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); + Value newReal = + b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); + real = + b.create(recalc, b.create(inf, newReal), real); // Recalculate imag part. 
- lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); - lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); - Value newImag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); - imag = b.create(recalc, b.create(inf, newImag), imag); + lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); + lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); + Value newImag = + b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); + imag = + b.create(recalc, b.create(inf, newImag), imag); rewriter.replaceOpWithNewOp(op, type, real, imag); return success(); @@ -524,8 +574,8 @@ rewriter.create(loc, elementType, adaptor.complex()); Value imag = rewriter.create(loc, elementType, adaptor.complex()); - Value negReal = rewriter.create(loc, real); - Value negImag = rewriter.create(loc, imag); + Value negReal = rewriter.create(loc, real); + Value negImag = rewriter.create(loc, imag); rewriter.replaceOpWithNewOp(op, type, negReal, negImag); return success(); } @@ -543,13 +593,16 @@ Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); - Value zero = b.create(elementType, b.getZeroAttr(elementType)); - Value realIsZero = b.create(CmpFPredicate::OEQ, real, zero); - Value imagIsZero = b.create(CmpFPredicate::OEQ, imag, zero); - Value isZero = b.create(realIsZero, imagIsZero); + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + Value realIsZero = + b.create(arith::CmpFPredicate::OEQ, real, zero); + Value imagIsZero = + b.create(arith::CmpFPredicate::OEQ, imag, zero); + Value isZero = b.create(realIsZero, imagIsZero); auto abs = b.create(elementType, adaptor.complex()); - Value realSign = b.create(real, abs); - Value imagSign = b.create(imag, abs); + Value realSign = b.create(real, abs); + Value imagSign = b.create(imag, abs); Value sign = b.create(type, realSign, imagSign); rewriter.replaceOpWithNewOp(op, isZero, adaptor.complex(), sign); return success(); @@ -562,10 +615,10 @@ // clang-format off patterns.add< AbsOpConversion, - ComparisonOpConversion, - ComparisonOpConversion, - BinaryComplexOpConversion, - BinaryComplexOpConversion, + ComparisonOpConversion, + ComparisonOpConversion, + BinaryComplexOpConversion, + BinaryComplexOpConversion, DivOpConversion, ExpOpConversion, LogOpConversion, @@ -590,7 +643,8 @@ populateComplexToStandardConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt --- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt @@ -29,6 +29,7 @@ ${NVPTX_LIBS} LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRAsyncToLLVM MLIRGPUTransforms MLIRIR diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -349,6 +350,7 @@ target.addIllegalDialect(); + mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, 
patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt @@ -11,11 +11,11 @@ MLIRGPUToNVVMIncGen LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRGPUOps MLIRGPUToGPURuntimeTransforms MLIRLLVMCommonConversion MLIRLLVMIR - MLIRMemRef MLIRMemRefToLLVM MLIRNVVMIR MLIRPass diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -13,11 +13,13 @@ #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" @@ -169,6 +171,8 @@ populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); + mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, + llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns); populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns); @@ -217,14 +221,14 @@ Identifier::get(NVVM::NVVMDialect::getKernelFuncAttrName(), &converter.getContext())); - patterns.add>(converter, "__nv_fabsf", - "__nv_fabs"); + patterns.add>(converter, "__nv_fabsf", + "__nv_fabs"); patterns.add>(converter, "__nv_atanf", "__nv_atan"); patterns.add>(converter, "__nv_atan2f", "__nv_atan2"); - patterns.add>(converter, "__nv_ceilf", - "__nv_ceil"); + patterns.add>(converter, "__nv_ceilf", + "__nv_ceil"); patterns.add>(converter, "__nv_cosf", "__nv_cos"); patterns.add>(converter, "__nv_expf", @@ -233,8 +237,8 @@ "__nv_exp2"); patterns.add>(converter, "__nv_expm1f", "__nv_expm1"); - patterns.add>(converter, "__nv_floorf", - "__nv_floor"); + patterns.add>(converter, "__nv_floorf", + "__nv_floor"); patterns.add>(converter, "__nv_logf", "__nv_log"); patterns.add>(converter, "__nv_log1pf", diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRGPUToROCDLIncGen LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRGPUOps MLIRGPUToGPURuntimeTransforms MLIRLLVMCommonConversion diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include 
"mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -72,6 +73,8 @@ populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); + mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, + llvmPatterns); populateVectorToLLVMConversionPatterns(converter, llvmPatterns); populateVectorToROCDLConversionPatterns(converter, llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns); @@ -116,14 +119,14 @@ converter, /*allocaAddrSpace=*/5, Identifier::get(ROCDL::ROCDLDialect::getKernelFuncAttrName(), &converter.getContext())); - patterns.add>(converter, "__ocml_fabs_f32", - "__ocml_fabs_f64"); + patterns.add>(converter, "__ocml_fabs_f32", + "__ocml_fabs_f64"); patterns.add>(converter, "__ocml_atan_f32", "__ocml_atan_f64"); patterns.add>( converter, "__ocml_atan2_f32", "__ocml_atan2_f64"); - patterns.add>(converter, "__ocml_ceil_f32", - "__ocml_ceil_f64"); + patterns.add>(converter, "__ocml_ceil_f32", + "__ocml_ceil_f64"); patterns.add>(converter, "__ocml_cos_f32", "__ocml_cos_f64"); patterns.add>(converter, "__ocml_exp_f32", @@ -132,8 +135,8 @@ "__ocml_exp2_f64"); patterns.add>( converter, "__ocml_expm1_f32", "__ocml_expm1_f64"); - patterns.add>(converter, "__ocml_floor_f32", - "__ocml_floor_f64"); + patterns.add>( + converter, "__ocml_floor_f32", "__ocml_floor_f64"); patterns.add>(converter, "__ocml_log_f32", "__ocml_log_f64"); patterns.add>( diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt @@ -6,13 +6,13 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRGPUOps MLIRIR MLIRPass MLIRSCFToSPIRV MLIRSPIRV MLIRSPIRVConversion - MLIRStandard MLIRStandardToSPIRV MLIRSupport MLIRTransforms diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" @@ -63,6 +64,7 @@ // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. 
+ mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -184,7 +184,8 @@ void ConvertLinalgToStandardPass::runOnOperation() { auto module = getOperation(); ConversionTarget target(getContext()); - target.addLegalDialect(); target.addLegalOp(); RewritePatternSet patterns(&getContext()); diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -18,9 +18,16 @@ using namespace mlir; namespace { +using AbsOpLowering = VectorConvertToLLVMPattern; +using CeilOpLowering = VectorConvertToLLVMPattern; +using CopySignOpLowering = + VectorConvertToLLVMPattern; using CosOpLowering = VectorConvertToLLVMPattern; using ExpOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = VectorConvertToLLVMPattern; +using FloorOpLowering = + VectorConvertToLLVMPattern; +using FmaOpLowering = VectorConvertToLLVMPattern; using Log10OpLowering = VectorConvertToLLVMPattern; using Log2OpLowering = VectorConvertToLLVMPattern; @@ -209,10 +216,15 @@ RewritePatternSet &patterns) { // clang-format off patterns.add< + AbsOpLowering, + CeilOpLowering, + CopySignOpLowering, CosOpLowering, ExpOpLowering, Exp2OpLowering, ExpM1OpLowering, + FloorOpLowering, + FmaOpLowering, Log10OpLowering, Log1pOpLowering, Log2OpLowering, diff --git a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt --- a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt +++ b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRMath MLIRStandardOpsTransforms ) diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -61,7 +62,7 @@ if (shape.size() != 1) return failure(); - Value result = rewriter.create( + Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, FloatAttr::get(vecType.getElementType(), 0.0))); for (auto i = 0; i < shape.front(); ++i) { @@ -135,8 +136,8 @@ populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalDialect(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -29,31 +30,6 @@ // normal RewritePattern. namespace { - -/// Converts unary and binary standard operations to SPIR-V operations. -template -class UnaryAndBinaryOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() <= 2); - auto dstType = this->getTypeConverter()->convertType(operation.getType()); - if (!dstType) - return failure(); - if (SPIRVOp::template hasTrait() && - dstType != operation.getType()) { - return operation.emitError( - "bitwidth emulation is not implemented yet on unsigned op"); - } - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - return success(); - } -}; - /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to @@ -76,7 +52,6 @@ return success(); } }; - } // namespace //===----------------------------------------------------------------------===// @@ -86,15 +61,19 @@ namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern>( + patterns.add< + Log1pOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern>( typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt --- a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt +++ b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt @@ -8,8 +8,9 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIROpenACC - MLIRTransforms MLIRSCF + MLIRTransforms ) diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp --- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp +++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp @@ -8,6 +8,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -33,7 +34,7 @@ return success(); // Condition is not a constant. 
- if (!op.ifCond().template getDefiningOp()) { + if (!op.ifCond().template getDefiningOp()) { auto ifOp = rewriter.create(op.getLoc(), TypeRange(), op.ifCond(), false); rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt b/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt @@ -12,6 +12,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRIR MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -65,6 +66,7 @@ // Convert to OpenMP operations with LLVM IR dialect RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); + mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -23,6 +23,10 @@ class OpenACCDialect; } // end namespace acc +namespace arith { +class ArithmeticDialect; +} // end namespace arith + namespace complex { class ComplexDialect; } // end namespace complex diff --git a/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt b/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt @@ -11,6 +11,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRAffineToStandard + MLIRArithmetic MLIRComplex MLIRGPUTransforms MLIRIR diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/ParallelLoopMapper.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -83,7 +84,8 @@ // Get a Value that corresponds to the loop step. If the step is an attribute, // materialize a corresponding constant using builder. static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { - return builder.create(forOp.getLoc(), forOp.getStep()); + return builder.create(forOp.getLoc(), + forOp.getStep()); } // Get a Value for the loop lower bound. If the value requires computation, @@ -169,8 +171,8 @@ // Return true if the value is obviously a constant "one". 
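In SCFToGPU the bound and step handling now goes through the arith dialect: the constant-one check below looks for `arith::ConstantIndexOp`, and the loop range that feeds the launch dimensions is built from arith ops. A rough sketch with invented names, assuming the non-unit-step case uses a signed division:

```mlir
%c1    = arith.constant 1 : index
%range = arith.subi %upperBound, %lowerBound : index
%iters = arith.divsi %range, %step : index
```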
static bool isConstantOne(Value value) { - if (auto def = value.getDefiningOp()) - return def.getValue() == 1; + if (auto def = value.getDefiningOp()) + return def.value() == 1; return false; } @@ -194,11 +196,11 @@ return llvm::None; } - Value range = - builder.create(currentLoop.getLoc(), upperBound, lowerBound); + Value range = builder.create(currentLoop.getLoc(), + upperBound, lowerBound); Value step = getOrCreateStep(currentLoop, builder); if (!isConstantOne(step)) - range = builder.create(currentLoop.getLoc(), range, step); + range = builder.create(currentLoop.getLoc(), range, step); dims.push_back(range); lbs.push_back(lowerBound); @@ -222,9 +224,10 @@ OpBuilder builder(rootForOp.getOperation()); // Prepare the grid and block sizes for the launch operation. If there is // no loop mapped to a specific dimension, use constant "1" as its size. - Value constOne = (numBlockDims < 3 || numThreadDims < 3) - ? builder.create(rootForOp.getLoc(), 1) - : nullptr; + Value constOne = + (numBlockDims < 3 || numThreadDims < 3) + ? builder.create(rootForOp.getLoc(), 1) + : nullptr; Value gridSizeX = numBlockDims > 0 ? dims[0] : constOne; Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne; Value gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; @@ -265,10 +268,10 @@ : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); Value step = steps[en.index()]; if (!isConstantOne(step)) - id = builder.create(rootForOp.getLoc(), step, id); + id = builder.create(rootForOp.getLoc(), step, id); Value ivReplacement = - builder.create(rootForOp.getLoc(), *lbArgumentIt, id); + builder.create(rootForOp.getLoc(), *lbArgumentIt, id); en.value().replaceAllUsesWith(ivReplacement); std::advance(lbArgumentIt, 1); std::advance(stepArgumentIt, 1); @@ -314,33 +317,33 @@ /// `upperBound`. static Value deriveStaticUpperBound(Value upperBound, PatternRewriter &rewriter) { - if (auto op = upperBound.getDefiningOp()) { + if (auto op = upperBound.getDefiningOp()) { return op; } if (auto minOp = upperBound.getDefiningOp()) { for (const AffineExpr &result : minOp.map().getResults()) { if (auto constExpr = result.dyn_cast()) { - return rewriter.create(minOp.getLoc(), - constExpr.getValue()); + return rewriter.create(minOp.getLoc(), + constExpr.getValue()); } } } - if (auto multiplyOp = upperBound.getDefiningOp()) { - if (auto lhs = dyn_cast_or_null( + if (auto multiplyOp = upperBound.getDefiningOp()) { + if (auto lhs = dyn_cast_or_null( deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter) .getDefiningOp())) - if (auto rhs = dyn_cast_or_null( + if (auto rhs = dyn_cast_or_null( deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter) .getDefiningOp())) { // Assumptions about the upper bound of minimum computations no longer // work if multiplied by a negative value, so abort in this case. - if (lhs.getValue() < 0 || rhs.getValue() < 0) + if (lhs.value() < 0 || rhs.value() < 0) return {}; - return rewriter.create( - multiplyOp.getLoc(), lhs.getValue() * rhs.getValue()); + return rewriter.create( + multiplyOp.getLoc(), lhs.value() * rhs.value()); } } @@ -416,8 +419,9 @@ launchIndependent](Value val) -> Value { if (launchIndependent(val)) return val; - if (ConstantOp constOp = val.getDefiningOp()) - return rewriter.create(constOp.getLoc(), constOp.getValue()); + if (auto constOp = val.getDefiningOp()) + return rewriter.create(constOp.getLoc(), + constOp.value()); return {}; }; @@ -460,17 +464,17 @@ // conditional. If the lower-bound is constant or defined before the // launch, we can use it in the launch bounds. 
Otherwise fail. if (!launchIndependent(lowerBound) && - !isa_and_nonnull(lowerBound.getDefiningOp())) + !isa_and_nonnull(lowerBound.getDefiningOp())) return failure(); // The step must also be constant or defined outside of the loop nest. if (!launchIndependent(step) && - !isa_and_nonnull(step.getDefiningOp())) + !isa_and_nonnull(step.getDefiningOp())) return failure(); // If the upper-bound is constant or defined before the launch, we can // use it in the launch bounds directly. Otherwise try derive a bound. bool boundIsPrecise = launchIndependent(upperBound) || - isa_and_nonnull(upperBound.getDefiningOp()); + isa_and_nonnull(upperBound.getDefiningOp()); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(launchOp); @@ -510,8 +514,8 @@ if (!boundIsPrecise) { // We are using an approximation, create a surrounding conditional. Value originalBound = std::get<3>(config); - CmpIOp pred = rewriter.create( - loc, CmpIPredicate::slt, newIndex, + arith::CmpIOp pred = rewriter.create( + loc, arith::CmpIPredicate::slt, newIndex, cloningMap.lookupOrDefault(originalBound)); scf::IfOp ifOp = rewriter.create(loc, pred, false); rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); @@ -595,7 +599,8 @@ // Create a launch operation. We start with bound one for all grid/block // sizes. Those will be refined later as we discover them from mappings. Location loc = parallelOp.getLoc(); - Value constantOne = rewriter.create(parallelOp.getLoc(), 1); + Value constantOne = + rewriter.create(parallelOp.getLoc(), 1); gpu::LaunchOp launchOp = rewriter.create( parallelOp.getLoc(), constantOne, constantOne, constantOne, constantOne, constantOne, constantOne); diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -10,6 +10,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/SCFToGPU/SCFToGPU.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/SCF/SCF.h" diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRAnalysis + MLIRArithmetic MLIRLLVMIR MLIROpenMP MLIRSCF diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "../PassDetail.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" @@ -248,27 +249,27 @@ // Match simple binary reductions that can be expressed with atomicrmw. 
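The reduction matcher now recognizes arith ops (alongside their LLVM dialect counterparts) inside the reduction region. For example, a floating-point add reduction that maps onto `atomicrmw fadd` looks roughly like:

```mlir
scf.reduce(%operand) : f32 {
^bb0(%lhs: f32, %rhs: f32):
  %sum = arith.addf %lhs, %rhs : f32
  scf.reduce.return %sum : f32
}
```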
Type type = reduce.operand().getType(); Block &reduction = reduce.getRegion().front(); - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getFloatAttr(type, 0.0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl( builder, symbolTable, reduce, builder.getIntegerAttr( @@ -279,25 +280,25 @@ // Match simple binary reductions that cannot be expressed with atomicrmw. // TODO: add atomic region using cmpxchg (which needs atomic load to be // available as an op). - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { return createDecl(builder, symbolTable, reduce, builder.getFloatAttr(type, 1.0)); } // Match select-based min/max reductions. bool isMin; - if (matchSelectReduction( - reduction, {CmpFPredicate::OLT, CmpFPredicate::OLE}, - {CmpFPredicate::OGT, CmpFPredicate::OGE}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, + {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || matchSelectReduction( reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { return createDecl(builder, symbolTable, reduce, minMaxValueForFloat(type, !isMin)); } - if (matchSelectReduction( - reduction, {CmpIPredicate::slt, CmpIPredicate::sle}, - {CmpIPredicate::sgt, CmpIPredicate::sge}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, + {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || matchSelectReduction( reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { @@ -307,9 +308,9 @@ isMin ? 
LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, decl, reduce); } - if (matchSelectReduction( - reduction, {CmpIPredicate::ult, CmpIPredicate::ule}, - {CmpIPredicate::ugt, CmpIPredicate::uge}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, + {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || matchSelectReduction( reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt @@ -9,6 +9,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRMemRefToSPIRV MLIRSPIRV MLIRSPIRVConversion diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" @@ -43,6 +44,7 @@ // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. + mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt b/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRSCF MLIRTransforms ) diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -314,7 +315,7 @@ Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.step(); - auto stepped = rewriter.create(loc, iv, step).getResult(); + auto stepped = rewriter.create(loc, iv, step).getResult(); if (!stepped) return failure(); @@ -341,8 +342,8 @@ // With the body block done, we can fill in the condition block. 
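After this lowering an `scf.for` becomes an explicit CFG: the condition block compares the induction variable with `arith.cmpi slt` and the body increments it with `arith.addi`. Schematically, using the `br`/`cond_br` terminators of this vintage and invented names:

```mlir
^cond(%iv: index):
  %cont = arith.cmpi slt, %iv, %ub : index
  cond_br %cont, ^body, ^exit
^body:
  // loop body ...
  %next = arith.addi %iv, %step : index
  br ^cond(%next : index)
```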
rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = - rewriter.create(loc, CmpIPredicate::slt, iv, upperBound); + auto comparison = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, upperBound); rewriter.create(loc, comparison, firstBodyBlock, ArrayRef(), endBlock, ArrayRef()); diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -0,0 +1,45 @@ +//===- Pattern.h - SPIRV Common Conversion Patterns -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H +#define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H + +#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace spirv { + +/// Converts unary and binary standard operations to SPIR-V operations. +template +class UnaryAndBinaryOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getOperands().size() <= 2); + auto dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return failure(); + if (SPIRVOp::template hasTrait() && + dstType != op.getType()) { + return op.emitError( + "bitwidth emulation is not implemented yet on unsigned op"); + } + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + return success(); + } +}; + +} // end namespace spirv +} // end namespace mlir + +#endif // MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H diff --git a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt @@ -11,6 +11,7 @@ intrinsics_gen LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRGPUOps MLIRSPIRV MLIRSPIRVUtils diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -287,6 +288,8 @@ auto *context = module.getContext(); RewritePatternSet patterns(context); LLVMTypeConverter typeConverter(context, options); + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); patterns.add(typeConverter); diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt @@ 
-17,6 +17,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRShape MLIRTensor diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -75,13 +76,13 @@ // number of extent tensors and shifted offsets into them. Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, ValueRange rankDiffs, Value outputDimension) { - Value one = lb.create(1); + Value one = lb.create(1); Value broadcastedDim = one; for (auto tup : llvm::zip(extentTensors, rankDiffs)) { Value shape = std::get<0>(tup); Value rankDiff = std::get<1>(tup); - Value outOfBounds = - lb.create(CmpIPredicate::ult, outputDimension, rankDiff); + Value outOfBounds = lb.create(arith::CmpIPredicate::ult, + outputDimension, rankDiff); Type indexTy = lb.getIndexType(); broadcastedDim = lb.create( @@ -97,13 +98,14 @@ // - otherwise, take the extent as-is. // Note that this logic remains correct in the presence // of dimensions of zero extent. - Value lesserRankOperandDimension = - b.create(loc, indexTy, outputDimension, rankDiff); + Value lesserRankOperandDimension = b.create( + loc, indexTy, outputDimension, rankDiff); Value lesserRankOperandExtent = b.create( loc, shape, ValueRange{lesserRankOperandDimension}); - Value dimIsOne = b.create(loc, CmpIPredicate::eq, - lesserRankOperandExtent, one); + Value dimIsOne = + b.create(loc, arith::CmpIPredicate::eq, + lesserRankOperandExtent, one); Value dim = b.create(loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); b.create(loc, dim); @@ -125,7 +127,7 @@ auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); + Value zero = lb.create(0); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -139,13 +141,14 @@ // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - Value rankIsGreater = lb.create(CmpIPredicate::ugt, v, maxRank); + Value rankIsGreater = + lb.create(arith::CmpIPredicate::ugt, v, maxRank); maxRank = lb.create(rankIsGreater, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. 
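The broadcast lowering above folds the operand ranks into a maximum by chaining `arith.cmpi ugt` with a select. A condensed sketch of that reduction, assuming index-typed `ranks`; the helper name `emitMaxRank` is illustrative:

```cpp
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: fold a non-empty list of index values to their maximum with
// arith.cmpi ugt + select, as the broadcast lowering does for ranks.
static Value emitMaxRank(OpBuilder &b, Location loc, ValueRange ranks) {
  Value maxRank = ranks.front();
  for (Value v : ranks.drop_front()) {
    Value isGreater =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, v, maxRank);
    maxRank = b.create<SelectOp>(loc, isGreater, v, maxRank);
  }
  return maxRank;
}
```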
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return lb.create(indexTy, maxRank, v); })); Value replacement = lb.create( @@ -186,7 +189,7 @@ SmallVector extentOperands; for (auto extent : op.shape()) { extentOperands.push_back( - rewriter.create(loc, extent.getLimitedValue())); + rewriter.create(loc, extent.getLimitedValue())); } Type indexTy = rewriter.getIndexType(); Value tensor = @@ -210,7 +213,8 @@ LogicalResult ConstSizeOpConversion::matchAndRewrite( ConstSizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(op, op.value().getSExtValue()); + rewriter.replaceOpWithNewOp( + op, op.value().getSExtValue()); return success(); } @@ -236,8 +240,8 @@ auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); - Value one = lb.create(1); + Value zero = lb.create(0); + Value one = lb.create(1); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -251,18 +255,19 @@ // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - Value rankIsGreater = lb.create(CmpIPredicate::ugt, v, maxRank); + Value rankIsGreater = + lb.create(arith::CmpIPredicate::ugt, v, maxRank); maxRank = lb.create(rankIsGreater, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return lb.create(indexTy, maxRank, v); })); Type i1Ty = rewriter.getI1Type(); Value trueVal = - rewriter.create(loc, i1Ty, rewriter.getBoolAttr(true)); + rewriter.create(loc, i1Ty, rewriter.getBoolAttr(true)); auto reduceResult = lb.create( loc, zero, maxRank, one, ValueRange{trueVal}, @@ -277,8 +282,8 @@ for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; - Value outOfBounds = - b.create(loc, CmpIPredicate::ult, iv, rankDiff); + Value outOfBounds = b.create( + loc, arith::CmpIPredicate::ult, iv, rankDiff); broadcastable = b.create( loc, TypeRange{i1Ty}, outOfBounds, @@ -290,18 +295,19 @@ // Every value needs to be either 1, or the same non-1 // value to be broadcastable in this dim. 
Value operandDimension = - b.create(loc, indexTy, iv, rankDiff); + b.create(loc, indexTy, iv, rankDiff); Value dimensionExtent = b.create( loc, shape, ValueRange{operandDimension}); - Value equalOne = b.create(loc, CmpIPredicate::eq, - dimensionExtent, one); - Value equalBroadcasted = - b.create(loc, CmpIPredicate::eq, - dimensionExtent, broadcastedDim); - Value result = b.create( + Value equalOne = b.create( + loc, arith::CmpIPredicate::eq, dimensionExtent, one); + Value equalBroadcasted = b.create( + loc, arith::CmpIPredicate::eq, dimensionExtent, + broadcastedDim); + Value result = b.create( loc, broadcastable, - b.create(loc, equalOne, equalBroadcasted)); + b.create(loc, equalOne, + equalBroadcasted)); b.create(loc, result); }) .getResult(0); @@ -389,8 +395,8 @@ auto loc = op.getLoc(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = rewriter.create(loc, indexTy, adaptor.shape(), zero); @@ -433,20 +439,20 @@ /// %c0 = constant 0 : index /// %0 = dim %arg0, %c0 : tensor /// %1 = dim %arg1, %c0 : tensor -/// %2 = cmpi "eq", %0, %1 : index +/// %2 = arith.cmpi "eq", %0, %1 : index /// %result = scf.if %2 -> (i1) { -/// %c1 = constant 1 : index -/// %true = constant true +/// %c1 = arith.constant 1 : index +/// %true = arith.constant true /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { /// %5 = tensor.extract %arg0[%arg2] : tensor /// %6 = tensor.extract %arg1[%arg2] : tensor -/// %7 = cmpi "eq", %5, %6 : index -/// %8 = and %arg3, %7 : i1 +/// %7 = arith.cmpi "eq", %5, %6 : index +/// %8 = arith.andi %arg3, %7 : i1 /// scf.yield %8 : i1 /// } /// scf.yield %4 : i1 /// } else { -/// %false = constant false +/// %false = arith.constant false /// scf.yield %false : i1 /// } /// @@ -468,14 +474,14 @@ Type i1Ty = rewriter.getI1Type(); if (op.shapes().size() <= 1) { - rewriter.replaceOpWithNewOp(op, i1Ty, - rewriter.getBoolAttr(true)); + rewriter.replaceOpWithNewOp(op, i1Ty, + rewriter.getBoolAttr(true)); return success(); } auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); - Value zero = rewriter.create(loc, 0); + Value zero = rewriter.create(loc, 0); Value firstShape = adaptor.shapes().front(); Value firstRank = rewriter.create(loc, indexTy, firstShape, zero); @@ -483,13 +489,14 @@ // Generate a linear sequence of compares, all with firstShape as lhs. 
for (Value shape : adaptor.shapes().drop_front(1)) { Value rank = rewriter.create(loc, indexTy, shape, zero); - Value eqRank = - rewriter.create(loc, CmpIPredicate::eq, firstRank, rank); + Value eqRank = rewriter.create(loc, arith::CmpIPredicate::eq, + firstRank, rank); auto same = rewriter.create( loc, i1Ty, eqRank, [&](OpBuilder &b, Location loc) { - Value one = b.create(loc, 1); - Value init = b.create(loc, i1Ty, b.getBoolAttr(true)); + Value one = b.create(loc, 1); + Value init = + b.create(loc, i1Ty, b.getBoolAttr(true)); auto loop = b.create( loc, zero, firstRank, one, ValueRange{init}, [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { @@ -497,19 +504,21 @@ Value lhsExtent = b.create(loc, firstShape, iv); Value rhsExtent = b.create(loc, shape, iv); - Value eqExtent = b.create(loc, CmpIPredicate::eq, - lhsExtent, rhsExtent); - Value conjNext = b.create(loc, conj, eqExtent); + Value eqExtent = b.create( + loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); + Value conjNext = b.create(loc, conj, eqExtent); b.create(loc, ValueRange({conjNext})); }); b.create(loc, loop.getResults()); }, [&](OpBuilder &b, Location loc) { - Value result = b.create(loc, i1Ty, b.getBoolAttr(false)); + Value result = + b.create(loc, i1Ty, b.getBoolAttr(false)); b.create(loc, result); }); result = !result ? same.getResult(0) - : rewriter.create(loc, result, same.getResult(0)); + : rewriter.create(loc, result, + same.getResult(0)); } rewriter.replaceOp(op, result); return success(); @@ -549,8 +558,8 @@ Value extent = rewriter.create(loc, tensor, i); extentValues.push_back(extent); } else { - Value extent = - rewriter.create(loc, rankedTensorTy.getDimSize(i)); + Value extent = rewriter.create( + loc, rankedTensorTy.getDimSize(i)); extentValues.push_back(extent); } } @@ -598,20 +607,20 @@ return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value zero = b.create(0); + Value zero = b.create(0); Value rank = b.create(adaptor.operand(), zero); // index < 0 ? index + rank : index Value originalIndex = adaptor.index(); - Value add = b.create(originalIndex, rank); + Value add = b.create(originalIndex, rank); Value indexIsNegative = - b.create(CmpIPredicate::slt, originalIndex, zero); + b.create(arith::CmpIPredicate::slt, originalIndex, zero); Value index = b.create(indexIsNegative, add, originalIndex); - Value one = b.create(1); + Value one = b.create(1); Value head = b.create(adaptor.operand(), zero, index, one); - Value tailSize = b.create(rank, index); + Value tailSize = b.create(rank, index); Value tail = b.create(adaptor.operand(), index, tailSize, one); rewriter.replaceOp(op, {head, tail}); @@ -655,8 +664,8 @@ // Setup target legality. MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target - .addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); // Setup conversion patterns. 
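The shape.split_at conversion above normalizes a possibly negative split index as `index < 0 ? index + rank : index`. A small sketch of that sequence with arith ops; `normalizeSplitIndex` is an illustrative name, not part of the patch:

```cpp
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: wrap a negative index around the rank, as in the SplitAtOp
// conversion above.
static Value normalizeSplitIndex(OpBuilder &b, Location loc, Value index,
                                 Value rank) {
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value shifted = b.create<arith::AddIOp>(loc, index, rank);
  Value isNegative =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, index, zero);
  return b.create<SelectOp>(loc, isNegative, shifted, index);
}
```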
@@ -675,8 +684,8 @@ populateWithGenerated(patterns); patterns.add< AnyOpConversion, - BinaryOpConversion, - BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, BroadcastOpConverter, ConstShapeOpConverter, ConstSizeOpConversion, diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRAnalysis + MLIRArithmeticToLLVM MLIRDataLayoutInterfaces MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -13,6 +13,7 @@ #include "../PassDetail.h" #include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" @@ -20,14 +21,12 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" @@ -390,54 +389,7 @@ }; // Straightforward lowerings. 
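Because the deletions that follow move the elementwise std-op lowerings out of StandardToLLVM, a pipeline that used to call only `populateStdToLLVMConversionPatterns` now has to add the arithmetic patterns as well, which is what the updated pass later in this file does. A minimal sketch under that assumption; the wrapper name and the exact includes are illustrative:

```cpp
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"

using namespace mlir;

// Sketch: register both pattern sets so arith.* and the remaining std.* ops
// lower to LLVM together.
static void populateLoweringPatterns(LLVMTypeConverter &typeConverter,
                                     RewritePatternSet &patterns) {
  arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
  populateStdToLLVMConversionPatterns(typeConverter, patterns);
}
```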
-using AbsFOpLowering = VectorConvertToLLVMPattern; -using AddFOpLowering = VectorConvertToLLVMPattern; -using AddIOpLowering = VectorConvertToLLVMPattern; -using AndOpLowering = VectorConvertToLLVMPattern; -using BitcastOpLowering = - VectorConvertToLLVMPattern; -using CeilFOpLowering = VectorConvertToLLVMPattern; -using CopySignOpLowering = - VectorConvertToLLVMPattern; -using DivFOpLowering = VectorConvertToLLVMPattern; -using FPExtOpLowering = VectorConvertToLLVMPattern; -using FPToSIOpLowering = VectorConvertToLLVMPattern; -using FPToUIOpLowering = VectorConvertToLLVMPattern; -using FPTruncOpLowering = - VectorConvertToLLVMPattern; -using FloorFOpLowering = VectorConvertToLLVMPattern; -using FmaFOpLowering = VectorConvertToLLVMPattern; -using MulFOpLowering = VectorConvertToLLVMPattern; -using MulIOpLowering = VectorConvertToLLVMPattern; -using NegFOpLowering = VectorConvertToLLVMPattern; -using OrOpLowering = VectorConvertToLLVMPattern; -using RemFOpLowering = VectorConvertToLLVMPattern; -using SIToFPOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = VectorConvertToLLVMPattern; -using SignExtendIOpLowering = - VectorConvertToLLVMPattern; -using ShiftLeftOpLowering = - VectorConvertToLLVMPattern; -using SignedDivIOpLowering = - VectorConvertToLLVMPattern; -using SignedRemIOpLowering = - VectorConvertToLLVMPattern; -using SignedShiftRightOpLowering = - VectorConvertToLLVMPattern; -using SubFOpLowering = VectorConvertToLLVMPattern; -using SubIOpLowering = VectorConvertToLLVMPattern; -using TruncateIOpLowering = - VectorConvertToLLVMPattern; -using UIToFPOpLowering = VectorConvertToLLVMPattern; -using UnsignedDivIOpLowering = - VectorConvertToLLVMPattern; -using UnsignedRemIOpLowering = - VectorConvertToLLVMPattern; -using UnsignedShiftRightOpLowering = - VectorConvertToLLVMPattern; -using XOrOpLowering = VectorConvertToLLVMPattern; -using ZeroExtendIOpLowering = - VectorConvertToLLVMPattern; /// Lower `std.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is @@ -651,118 +603,6 @@ } }; -// The lowering of index_cast becomes an integer conversion since index becomes -// an integer. If the bit width of the source and target integer types is the -// same, just erase the cast. If the target type is wider, sign-extend the -// value, otherwise truncate it. -struct IndexCastOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto targetType = - typeConverter->convertType(indexCastOp.getResult().getType()); - auto targetElementType = - typeConverter - ->convertType(getElementTypeOrSelf(indexCastOp.getResult())) - .cast(); - auto sourceElementType = - getElementTypeOrSelf(adaptor.in()).cast(); - unsigned targetBits = targetElementType.getWidth(); - unsigned sourceBits = sourceElementType.getWidth(); - - if (targetBits == sourceBits) - rewriter.replaceOp(indexCastOp, adaptor.in()); - else if (targetBits < sourceBits) - rewriter.replaceOpWithNewOp(indexCastOp, targetType, - adaptor.in()); - else - rewriter.replaceOpWithNewOp(indexCastOp, targetType, - adaptor.in()); - return success(); - } -}; - -// Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two -// enums share the numerical values so just cast. 
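The removed IndexCastOp lowering above follows a simple width rule: equal widths forward the operand, a narrower target truncates, a wider target sign-extends. A sketch of that rule, which now lives with the ArithmeticToLLVM patterns; `emitIndexCast` and its explicit width parameters are illustrative:

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: pick the LLVM cast for an index_cast based on bit widths.
static Value emitIndexCast(OpBuilder &b, Location loc, Value in,
                           Type targetType, unsigned sourceBits,
                           unsigned targetBits) {
  if (targetBits == sourceBits)
    return in;                                            // No-op cast.
  if (targetBits < sourceBits)
    return b.create<LLVM::TruncOp>(loc, targetType, in);  // Narrowing.
  return b.create<LLVM::SExtOp>(loc, targetType, in);     // Widening.
}
```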
-template -static LLVMPredType convertCmpPredicate(StdPredType pred) { - return static_cast(pred); -} - -struct CmpIOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto operandType = adaptor.lhs().getType(); - auto resultType = cmpiOp.getResult().getType(); - - // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { - rewriter.replaceOpWithNewOp( - cmpiOp, typeConverter->convertType(resultType), - convertCmpPredicate(cmpiOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type"); - - return LLVM::detail::handleMultidimensionalVectors( - cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - CmpIOpAdaptor adaptor(operands); - return rewriter.create( - cmpiOp.getLoc(), llvm1DVectorTy, - convertCmpPredicate(cmpiOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - }, - rewriter); - - return success(); - } -}; - -struct CmpFOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto operandType = adaptor.lhs().getType(); - auto resultType = cmpfOp.getResult().getType(); - - // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { - rewriter.replaceOpWithNewOp( - cmpfOp, typeConverter->convertType(resultType), - convertCmpPredicate(cmpfOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type"); - - return LLVM::detail::handleMultidimensionalVectors( - cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - CmpFOpAdaptor adaptor(operands); - return rewriter.create( - cmpfOp.getLoc(), llvm1DVectorTy, - convertCmpPredicate(cmpfOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - }, - rewriter); - } -}; - // Base class for LLVM IR lowering terminator operations with successors. 
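The deleted `convertCmpPredicate` helper above depends on the std/arith and LLVM predicate enums sharing their numerical values, so the conversion is a bare cast. Reproduced here as a sketch, with a hypothetical instantiation in the trailing comment:

```cpp
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

// Sketch: the enum values line up, so converting a predicate is a
// static_cast, exactly as the removed helper did.
template <typename LLVMPredType, typename ArithPredType>
static LLVMPredType convertCmpPredicate(ArithPredType pred) {
  return static_cast<LLVMPredType>(pred);
}

// Hypothetical instantiation:
// LLVM::ICmpPredicate p =
//     convertCmpPredicate<LLVM::ICmpPredicate>(arith::CmpIPredicate::slt);
```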
template struct OneToOneLLVMTerminatorLowering @@ -1131,57 +971,20 @@ populateStdToLLVMFuncOpConversionPattern(converter, patterns); // clang-format off patterns.add< - AbsFOpLowering, - AddFOpLowering, - AddIOpLowering, - AndOpLowering, AssertOpLowering, AtomicRMWOpLowering, - BitcastOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, - CeilFOpLowering, - CmpFOpLowering, - CmpIOpLowering, CondBranchOpLowering, - CopySignOpLowering, ConstantOpLowering, - DivFOpLowering, - FloorFOpLowering, - FmaFOpLowering, GenericAtomicRMWOpLowering, - FPExtOpLowering, - FPToSIOpLowering, - FPToUIOpLowering, - FPTruncOpLowering, - IndexCastOpLowering, - MulFOpLowering, - MulIOpLowering, - NegFOpLowering, - OrOpLowering, - RemFOpLowering, RankOpLowering, ReturnOpLowering, - SIToFPOpLowering, SelectOpLowering, - ShiftLeftOpLowering, - SignExtendIOpLowering, - SignedDivIOpLowering, - SignedRemIOpLowering, - SignedShiftRightOpLowering, SplatOpLowering, SplatNdOpLowering, - SubFOpLowering, - SubIOpLowering, - SwitchOpLowering, - TruncateIOpLowering, - UIToFPOpLowering, - UnsignedDivIOpLowering, - UnsignedRemIOpLowering, - UnsignedShiftRightOpLowering, - XOrOpLowering, - ZeroExtendIOpLowering>(converter); + SwitchOpLowering>(converter); // clang-format on } @@ -1231,6 +1034,7 @@ RewritePatternSet patterns(&getContext()); populateStdToLLVMConversionPatterns(typeConverter, patterns); + arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -10,8 +10,9 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRIR - MLIRMath + MLIRMathToSPIRV MLIRMemRef MLIRPass MLIRSPIRV diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -29,15 +30,6 @@ // Utility functions //===----------------------------------------------------------------------===// -/// Returns true if the given `type` is a boolean scalar or vector type. -static bool isBoolScalarOrVector(Type type) { - if (type.isInteger(1)) - return true; - if (auto vecType = type.dyn_cast()) - return vecType.getElementType().isInteger(1); - return false; -} - /// Converts the given `srcAttr` into a boolean attribute if it holds an /// integral value. Returns null attribute if conversion fails. static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { @@ -98,35 +90,6 @@ return builder.getF32FloatAttr(dstVal.convertToFloat()); } -/// Returns signed remainder for `lhs` and `rhs` and lets the result follow -/// the sign of `signOperand`. -/// -/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment -/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative -/// the result is undefined." 
So we cannot directly use spv.SRem/spv.SMod -/// if either operand can be negative. Emulate it via spv.UMod. -static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, - Value signOperand, OpBuilder &builder) { - assert(lhs.getType() == rhs.getType()); - assert(lhs == signOperand || rhs == signOperand); - - Type type = lhs.getType(); - - // Calculate the remainder with spv.UMod. - Value lhsAbs = builder.create(loc, type, lhs); - Value rhsAbs = builder.create(loc, type, rhs); - Value abs = builder.create(loc, lhsAbs, rhsAbs); - - // Fix the sign. - Value isPositive; - if (lhs == signOperand) - isPositive = builder.create(loc, lhs, lhsAbs); - else - isPositive = builder.create(loc, rhs, rhsAbs); - Value absNegate = builder.create(loc, type, abs); - return builder.create(loc, type, isPositive, abs, absNegate); -} - //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -137,71 +100,6 @@ namespace { -/// Converts unary and binary standard operations to SPIR-V operations. -template -class UnaryAndBinaryOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() <= 2); - auto dstType = this->getTypeConverter()->convertType(operation.getType()); - if (!dstType) - return failure(); - if (SPIRVOp::template hasTrait() && - dstType != operation.getType()) { - return operation.emitError( - "bitwidth emulation is not implemented yet on unsigned op"); - } - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - return success(); - } -}; - -/// Converts std.remi_signed to SPIR-V ops. -/// -/// This cannot be merged into the template unary/binary pattern due to -/// Vulkan restrictions over spv.SRem and spv.SMod. -class SignedRemIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(SignedRemIOp remOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts bitwise standard operations to SPIR-V operations. This is a special -/// pattern other than the BinaryOpPatternPattern because if the operands are -/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For -/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. -template -class BitwiseOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() == 2); - auto dstType = - this->getTypeConverter()->convertType(operation.getResult().getType()); - if (!dstType) - return failure(); - if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { - rewriter.template replaceOpWithNewOp( - operation, dstType, adaptor.getOperands()); - } else { - rewriter.template replaceOpWithNewOp( - operation, dstType, adaptor.getOperands()); - } - return success(); - } -}; - /// Converts composite std.constant operation to spv.Constant. 
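The removed `emulateSignedRemainder` helper above sidesteps the Vulkan restriction on spv.SRem/spv.SMod by taking the remainder of absolute values with spv.UMod and then restoring the sign of the designated operand with a select. A sketch of that emulation, assuming the GLSL SAbs, UMod, SNegate, and Select builders the original helper relied on:

```cpp
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: |lhs| umod |rhs|, then select between +abs and -abs based on the
// sign of the operand whose sign the result must follow.
static Value emulateSignedRemainder(OpBuilder &b, Location loc, Value lhs,
                                    Value rhs, Value signOperand) {
  Type type = lhs.getType();
  Value lhsAbs = b.create<spirv::GLSLSAbsOp>(loc, type, lhs);
  Value rhsAbs = b.create<spirv::GLSLSAbsOp>(loc, type, rhs);
  Value abs = b.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
  Value isPositive = b.create<spirv::IEqualOp>(
      loc, signOperand, signOperand == lhs ? lhsAbs : rhsAbs);
  Value absNegate = b.create<spirv::SNegateOp>(loc, type, abs);
  return b.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
}
```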
class ConstantCompositeOpPattern final : public OpConversionPattern { @@ -223,58 +121,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts floating-point comparison operations to SPIR-V ops. -class CmpFOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating point NaN check to SPIR-V ops. This pattern requires -/// Kernel capability. -class CmpFOpNanKernelPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating point NaN check to SPIR-V ops. This pattern does not -/// require additional capability. -class CmpFOpNanNonePattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts integer compare operation on i1 type operands to SPIR-V ops. -class BoolCmpIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts integer compare operation to SPIR-V ops. -class CmpIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - /// Converts std.return to spv.Return. class ReturnOpPattern final : public OpConversionPattern { public: @@ -304,30 +150,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts std.zexti to spv.Select if the type of source is i1 or vector of -/// i1. -class ZeroExtendI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ZeroExtendIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = adaptor.getOperands().front().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); - - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Location loc = op.getLoc(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.template replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); - return success(); - } -}; - /// Converts tensor.extract into loading using access chains from SPIR-V local /// variables. class TensorExtractPattern final @@ -389,124 +211,8 @@ int64_t byteCountThreshold; }; -/// Converts std.trunci to spv.Select if the type of result is i1 or vector of -/// i1. 
-class TruncI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(TruncateIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - if (!isBoolScalarOrVector(dstType)) - return failure(); - - Location loc = op.getLoc(); - auto srcType = adaptor.getOperands().front().getType(); - // Check if (x & 1) == 1. - Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create(loc, maskedSrc, mask); - - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); - return success(); - } -}; - -/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of -/// i1. -class UIToFPI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(UIToFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = adaptor.getOperands().front().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); - - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Location loc = op.getLoc(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.template replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); - return success(); - } -}; - -/// Converts type-casting standard operations to SPIR-V operations. -template -class TypeCastingOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() == 1); - auto srcType = adaptor.getOperands().front().getType(); - auto dstType = - this->getTypeConverter()->convertType(operation.getResult().getType()); - if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) - return failure(); - if (dstType == srcType) { - // Due to type conversion, we are seeing the same source and target type. - // Then we can just erase this operation by forwarding its operand. - rewriter.replaceOp(operation, adaptor.getOperands().front()); - } else { - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - } - return success(); - } -}; - -/// Converts std.xor to SPIR-V operations. -class XOrOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector -/// of i1. 
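The removed TruncI1Pattern above lowers a truncation to i1 by testing the low bit: mask the source with one, compare for equality, and select between the i1 constants. A sketch of that sequence, assuming the spv.BitwiseAnd and spv.IEqual builders the pattern used; `emitTruncToI1` is an illustrative name:

```cpp
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: trunc iN %x to i1  ==>  ((x & 1) == 1) ? true : false
static Value emitTruncToI1(OpBuilder &b, Location loc, Value src,
                           Type dstType) {
  Type srcType = src.getType();
  Value mask = spirv::ConstantOp::getOne(srcType, loc, b);
  Value masked = b.create<spirv::BitwiseAndOp>(loc, srcType, src, mask);
  Value isOne = b.create<spirv::IEqualOp>(loc, masked, mask);
  Value zero = spirv::ConstantOp::getZero(dstType, loc, b);
  Value one = spirv::ConstantOp::getOne(dstType, loc, b);
  return b.create<spirv::SelectOp>(loc, dstType, isOne, one, zero);
}
```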
-class BoolXOrOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - } // namespace -//===----------------------------------------------------------------------===// -// SignedRemIOpPattern -//===----------------------------------------------------------------------===// - -LogicalResult SignedRemIOpPattern::matchAndRewrite( - SignedRemIOp remOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value result = emulateSignedRemainder( - remOp.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], - adaptor.getOperands()[0], rewriter); - rewriter.replaceOp(remOp, result); - - return success(); -} - //===----------------------------------------------------------------------===// // ConstantOp with composite type. //===----------------------------------------------------------------------===// @@ -649,143 +355,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -LogicalResult -CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - switch (cmpFOp.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - // Ordered. - DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); - DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); - DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); - DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); - DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); - DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); - // Unordered. 
- DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); - DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); - DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); - DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); - DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); - DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); - -#undef DISPATCH - - default: - break; - } - return failure(); -} - -LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( - CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (cmpFOp.getPredicate() == CmpFPredicate::ORD) { - rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), - adaptor.rhs()); - return success(); - } - - if (cmpFOp.getPredicate() == CmpFPredicate::UNO) { - rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), - adaptor.rhs()); - return success(); - } - - return failure(); -} - -LogicalResult CmpFOpNanNonePattern::matchAndRewrite( - CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (cmpFOp.getPredicate() != CmpFPredicate::ORD && - cmpFOp.getPredicate() != CmpFPredicate::UNO) - return failure(); - - Location loc = cmpFOp.getLoc(); - - Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); - - Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); - if (cmpFOp.getPredicate() == CmpFPredicate::ORD) - replace = rewriter.create(loc, replace); - - rewriter.replaceOp(cmpFOp, replace); - return success(); -} - -//===----------------------------------------------------------------------===// -// CmpIOp -//===----------------------------------------------------------------------===// - -LogicalResult -BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type operandType = cmpIOp.lhs().getType(); - if (!isBoolScalarOrVector(operandType)) - return failure(); - - switch (cmpIOp.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); - DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp); - -#undef DISPATCH - default:; - } - return failure(); -} - -LogicalResult -CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type operandType = cmpIOp.lhs().getType(); - if (isBoolScalarOrVector(operandType)) - return failure(); - - switch (cmpIOp.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - if (spirvOp::template hasTrait() && \ - operandType != this->getTypeConverter()->convertType(operandType)) { \ - return cmpIOp.emitError( \ - "bitwidth emulation is not implemented yet on unsigned op"); \ - } \ - rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); - DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); - DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); - DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); - DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); - DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); - DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); - DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); - DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); - DISPATCH(CmpIPredicate::uge, 
spirv::UGreaterThanEqualOp); - -#undef DISPATCH - } - return failure(); -} - //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// @@ -833,43 +402,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// XorOp -//===----------------------------------------------------------------------===// - -LogicalResult -XOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); - - if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) - return failure(); - - auto dstType = getTypeConverter()->convertType(xorOp.getType()); - if (!dstType) - return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, - adaptor.getOperands()); - - return success(); -} - -LogicalResult -BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); - - if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) - return failure(); - - auto dstType = getTypeConverter()->convertType(xorOp.getType()); - if (!dstType) - return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, - adaptor.getOperands()); - return success(); -} - //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -881,60 +413,17 @@ patterns.add< // Unary and binary patterns - BitwiseOpPattern, - BitwiseOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, - - // Comparison patterns - BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, // Constant patterns ConstantCompositeOpPattern, ConstantScalarOpPattern, - ReturnOpPattern, SelectOpPattern, SplatPattern, - - // Type cast patterns - UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern>(typeConverter, - context); - - // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel - // capability is available. 
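With the arithmetic, comparison, and cast patterns gone from `populateStandardToSPIRVPatterns`, a SPIR-V conversion pipeline now assembles the pieces itself, as the updated pass in the next hunk does. A minimal sketch of that registration, assuming an existing SPIRVTypeConverter and RewritePatternSet; the wrapper name is illustrative:

```cpp
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"

using namespace mlir;

// Sketch: populate arith, math, and the remaining std patterns together.
static void populateSPIRVLoweringPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
  arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
  populateMathToSPIRVPatterns(typeConverter, patterns);
  populateStandardToSPIRVPatterns(typeConverter, patterns);
}
```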
- patterns.add(typeConverter, context, - /*benefit=*/2); + ReturnOpPattern, SelectOpPattern, SplatPattern>(typeConverter, context); } void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp @@ -12,6 +12,8 @@ #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" +#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -38,10 +40,13 @@ options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; SPIRVTypeConverter typeConverter(targetAttr, options); + // TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV RewritePatternSet patterns(context); + arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); + populateMathToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); - populateTensorToSPIRVPatterns(typeConverter, - /*byteCountThreshold=*/64, patterns); + populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, + patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); if (failed(applyPartialConversion(module, *target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt --- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRDialectUtils MLIRIR MLIRLinalg diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" @@ -32,12 +33,12 @@ } template -static mlir::ConstantOp +static arith::ConstantOp createConstFromIntAttribute(Operation *op, std::string attrName, Type requiredAttrType, OpBuilder &rewriter) { auto castedN = static_cast( op->getAttr(attrName).cast().getValue().getSExtValue()); - return rewriter.create( + return rewriter.create( op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); } @@ -50,9 +51,9 @@ } template -static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min, - mlir::ConstantOp max, P pred, - OpBuilder &rewriter) { +static mlir::SelectOp clampHelper(Location loc, Value arg, + arith::ConstantOp min, arith::ConstantOp max, + P pred, OpBuilder &rewriter) { auto smallerThanMin = rewriter.create(loc, pred, arg, min); auto minOrArg = rewriter.create(loc, smallerThanMin, min, arg); @@ -83,7 +84,7 @@ highIndices.push_back(rewriter.getIndexAttr(highPad)); } - Value padValue = rewriter.create(loc, padAttr); + Value padValue = rewriter.create(loc, padAttr); return linalg::PadTensorOp::createPadScalarOp( 
RankedTensorType::get(paddedShape, inputETy), input, padValue, @@ -109,30 +110,30 @@ // tosa::AbsOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa()) { - auto zero = - rewriter.create(loc, rewriter.getZeroAttr(elementTy)); - auto cmp = - rewriter.create(loc, CmpIPredicate::sgt, args[0], zero); - auto neg = rewriter.create(loc, zero, args[0]); + auto zero = rewriter.create( + loc, rewriter.getZeroAttr(elementTy)); + auto cmp = rewriter.create(loc, arith::CmpIPredicate::sgt, + args[0], zero); + auto neg = rewriter.create(loc, zero, args[0]); return rewriter.create(loc, cmp, args[0], neg); } // tosa::AddOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::SubOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::MulOp if (isa(op) && elementTy.isa()) { @@ -141,18 +142,18 @@ "Cannot have shift value for float"); return nullptr; } - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); } // tosa::DivOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ReciprocalOp if (isa(op) && elementTy.isa()) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - return rewriter.create(loc, resultTypes, one, args[0]); + rewriter.create(loc, FloatAttr::get(elementTy, 1)); + return rewriter.create(loc, resultTypes, one, args[0]); } if (isa(op) && elementTy.isa()) { @@ -162,12 +163,12 @@ op->getAttr("shift").cast().getValue().getSExtValue(); if (shift > 0) { auto shiftConst = - rewriter.create(loc, shift, /*bitwidth=*/8); + rewriter.create(loc, shift, /*bitwidth=*/8); if (!a.getType().isInteger(32)) - a = rewriter.create(loc, rewriter.getI32Type(), a); + a = rewriter.create(loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) - b = rewriter.create(loc, rewriter.getI32Type(), b); + b = rewriter.create(loc, rewriter.getI32Type(), b); auto result = rewriter.create( loc, rewriter.getI32Type(), a, b, shiftConst, @@ -176,7 +177,7 @@ if (elementTy.isInteger(32)) return result; - return rewriter.create(loc, elementTy, result); + return rewriter.create(loc, elementTy, result); } int aWidth = a.getType().getIntOrFloatBitWidth(); @@ -184,22 +185,22 @@ int cWidth = resultTypes[0].getIntOrFloatBitWidth(); if (aWidth < cWidth) - a = rewriter.create(loc, resultTypes[0], a); + a = rewriter.create(loc, resultTypes[0], a); if (bWidth < cWidth) - b = rewriter.create(loc, resultTypes[0], b); + b = rewriter.create(loc, resultTypes[0], b); - return rewriter.create(loc, resultTypes, a, b); + return rewriter.create(loc, resultTypes, a, b); } // tosa::NegateOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa() && !cast(op).quantization_info()) { auto constant = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); - return rewriter.create(loc, resultTypes, constant, args[0]); + rewriter.create(loc, 
IntegerAttr::get(elementTy, 0)); + return rewriter.create(loc, resultTypes, constant, args[0]); } if (isa(op) && elementTy.isa() && @@ -228,62 +229,59 @@ } Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); - Value zpAddValue = rewriter.create( + Value zpAddValue = rewriter.create( loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue - auto ext = rewriter.create(loc, intermediateType, args[0]); - auto sub = rewriter.create(loc, zpAddValue, ext); + auto ext = rewriter.create(loc, intermediateType, args[0]); + auto sub = rewriter.create(loc, zpAddValue, ext); // Clamp to the negation range. - auto min = rewriter.create( - loc, rewriter.getIntegerAttr( - intermediateType, - APInt::getSignedMinValue(inputBitWidth).getSExtValue())); - auto max = rewriter.create( - loc, rewriter.getIntegerAttr( - intermediateType, - APInt::getSignedMaxValue(inputBitWidth).getSExtValue())); - auto clamp = clampHelper(loc, sub, min, max, - CmpIPredicate::slt, rewriter); + auto min = rewriter.create( + loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(), + intermediateType); + auto max = rewriter.create( + loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(), + intermediateType); + auto clamp = clampHelper( + loc, sub, min, max, arith::CmpIPredicate::slt, rewriter); // Truncate to the final value. - return rewriter.create(loc, elementTy, clamp); + return rewriter.create(loc, elementTy, clamp); } // tosa::BitwiseAndOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::BitwiseOrOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::BitwiseNotOp if (isa(op) && elementTy.isa()) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); - auto allOnes = rewriter.create(loc, allOnesAttr); - return rewriter.create(loc, resultTypes, args[0], allOnes); + auto allOnes = rewriter.create(loc, allOnesAttr); + return rewriter.create(loc, resultTypes, args[0], allOnes); } // tosa::BitwiseXOrOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalLeftShiftOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalRightShiftOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ArithmeticRightShiftOp if (isa(op) && elementTy.isa()) { - auto result = - rewriter.create(loc, resultTypes, args); + auto result = rewriter.create(loc, resultTypes, args); auto round = op->getAttr("round").cast().getValue(); if (!round) { return result; @@ -291,40 +289,40 @@ Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); auto one = - rewriter.create(loc, IntegerAttr::get(elementTy, 1)); + rewriter.create(loc, IntegerAttr::get(elementTy, 1)); auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto i1one = - rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); + rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); // Checking that input2 != 0 - auto shiftValueGreaterThanZero = - rewriter.create(loc, CmpIPredicate::sgt, args[1], zero); + auto 
shiftValueGreaterThanZero = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[1], zero); // Checking for the last bit of input1 to be 1 auto subtract = - rewriter.create(loc, resultTypes, args[1], one); - auto shifted = rewriter - .create(loc, resultTypes, - args[0], subtract) - ->getResults(); + rewriter.create(loc, resultTypes, args[1], one); + auto shifted = + rewriter.create(loc, resultTypes, args[0], subtract) + ->getResults(); auto truncated = - rewriter.create(loc, i1Ty, shifted, mlir::None); - auto isInputOdd = rewriter.create(loc, i1Ty, truncated, i1one); + rewriter.create(loc, i1Ty, shifted, mlir::None); + auto isInputOdd = + rewriter.create(loc, i1Ty, truncated, i1one); - auto shouldRound = rewriter.create( + auto shouldRound = rewriter.create( loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); auto extended = - rewriter.create(loc, resultTypes, shouldRound); - return rewriter.create(loc, resultTypes, result, extended); + rewriter.create(loc, resultTypes, shouldRound); + return rewriter.create(loc, resultTypes, result, extended); } // tosa::ClzOp if (isa(op) && elementTy.isa()) { int bitWidth = elementTy.getIntOrFloatBitWidth(); auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); - auto leadingZeros = rewriter.create( + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + auto leadingZeros = rewriter.create( loc, IntegerAttr::get(elementTy, bitWidth)); SmallVector operands = {args[0], leadingZeros, zero}; @@ -340,8 +338,8 @@ Value input = before->getArgument(0); Value zero = before->getArgument(2); - Value inputLargerThanZero = - rewriter.create(loc, CmpIPredicate::ne, input, zero); + Value inputLargerThanZero = rewriter.create( + loc, arith::CmpIPredicate::ne, input, zero); rewriter.create(loc, inputLargerThanZero, before->getArguments()); } @@ -352,12 +350,12 @@ Value input = after->getArgument(0); Value leadingZeros = after->getArgument(1); - auto one = rewriter.create( + auto one = rewriter.create( loc, IntegerAttr::get(elementTy, 1)); - auto shifted = rewriter.create( - loc, resultTypes, input, one); + auto shifted = + rewriter.create(loc, resultTypes, input, one); auto leadingZerosMinusOne = - rewriter.create(loc, resultTypes, leadingZeros, one); + rewriter.create(loc, resultTypes, leadingZeros, one); rewriter.create( loc, @@ -370,22 +368,22 @@ // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalNot if (isa(op) && elementTy.isInteger(1)) { - auto one = rewriter.create( + auto one = rewriter.create( loc, rewriter.getIntegerAttr(elementTy, 1)); - return rewriter.create(loc, resultTypes, args[0], one); + return rewriter.create(loc, resultTypes, args[0], one); } // tosa::LogicalOr if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalXor if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::PowOp if (isa(op) && elementTy.isa()) @@ -409,30 +407,30 @@ // tosa::GreaterOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OGT, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OGT, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::sgt, args[0], - args[1]); + return rewriter.create(loc, arith::CmpIPredicate::sgt, + args[0], args[1]); // 
tosa::GreaterEqualOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OGE, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OGE, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::sge, args[0], - args[1]); + return rewriter.create(loc, arith::CmpIPredicate::sge, + args[0], args[1]); // tosa::EqualOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OEQ, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OEQ, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::eq, args[0], - args[1]); + return rewriter.create(loc, arith::CmpIPredicate::eq, + args[0], args[1]); // tosa::SelectOp if (isa(op)) { @@ -443,46 +441,46 @@ // tosa::MaximumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OGT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OGT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { - auto predicate = rewriter.create(loc, CmpIPredicate::sgt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::MinimumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OLT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { - auto predicate = rewriter.create(loc, CmpIPredicate::slt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::CeilOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::FloorOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ClampOp if (isa(op) && elementTy.isa()) { - auto min = rewriter.create(loc, elementTy, - op->getAttr("min_fp")); - auto max = rewriter.create(loc, elementTy, - op->getAttr("max_fp")); - return clampHelper(loc, args[0], min, max, CmpFPredicate::OLT, - rewriter); + auto min = rewriter.create(loc, elementTy, + op->getAttr("min_fp")); + auto max = rewriter.create(loc, elementTy, + op->getAttr("max_fp")); + return clampHelper(loc, args[0], min, max, + arith::CmpFPredicate::OLT, rewriter); } if (isa(op) && elementTy.isa()) { @@ -506,41 +504,41 @@ .getSExtValue()); } - auto minVal = - rewriter.create(loc, min, intTy.getIntOrFloatBitWidth()); - auto maxVal = - rewriter.create(loc, max, intTy.getIntOrFloatBitWidth()); - return clampHelper(loc, args[0], minVal, maxVal, - CmpIPredicate::slt, rewriter); + auto minVal = rewriter.create( + loc, min, intTy.getIntOrFloatBitWidth()); + auto maxVal = rewriter.create( + loc, max, intTy.getIntOrFloatBitWidth()); + return clampHelper(loc, args[0], minVal, maxVal, + arith::CmpIPredicate::slt, rewriter); } // tosa::ReluNOp if (isa(op) && elementTy.isa()) { auto zero = - rewriter.create(loc, FloatAttr::get(elementTy, 0)); - auto n = rewriter.create(loc, elementTy, - op->getAttr("max_fp")); - return clampHelper(loc, args[0], zero, n, 
CmpFPredicate::OLT, - rewriter); + rewriter.create(loc, FloatAttr::get(elementTy, 0)); + auto n = rewriter.create(loc, elementTy, + op->getAttr("max_fp")); + return clampHelper(loc, args[0], zero, n, + arith::CmpFPredicate::OLT, rewriter); } if (isa(op) && elementTy.isa()) { auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto n = createConstFromIntAttribute(op, "max_int", elementTy, rewriter); - return clampHelper(loc, args[0], zero, n, CmpIPredicate::slt, - rewriter); + return clampHelper(loc, args[0], zero, n, + arith::CmpIPredicate::slt, rewriter); } // tosa::SigmoidOp if (isa(op) && elementTy.isa()) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - auto negate = rewriter.create(loc, resultTypes, args[0]); + rewriter.create(loc, FloatAttr::get(elementTy, 1)); + auto negate = rewriter.create(loc, resultTypes, args[0]); auto exp = rewriter.create(loc, resultTypes, negate); - auto added = rewriter.create(loc, resultTypes, exp, one); - return rewriter.create(loc, resultTypes, one, added); + auto added = rewriter.create(loc, resultTypes, exp, one); + return rewriter.create(loc, resultTypes, one, added); } // tosa::CastOp @@ -554,25 +552,25 @@ return args.front(); if (srcTy.isa() && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, mlir::None); + return rewriter.create(loc, resultTypes, args, mlir::None); if (srcTy.isa() && dstTy.isa() && !bitExtend) - return rewriter.create(loc, resultTypes, args, + return rewriter.create(loc, resultTypes, args, mlir::None); // 1-bit integers need to be treated as signless. - if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, - mlir::None); + if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); if (srcTy.isInteger(1) && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, - mlir::None); + return rewriter.create(loc, resultTypes, args, + mlir::None); // All other si-to-fp conversions should be handled by SIToFP. - if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, - mlir::None); + if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); // Unsigned integers need an unrealized cast so that they can be passed // to UIToFP. @@ -583,76 +581,76 @@ loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); - return rewriter.create(loc, resultTypes[0], - unrealizedCast); + return rewriter.create(loc, resultTypes[0], + unrealizedCast); } // Casting to boolean, floats need to only be checked as not-equal to zero. 
if (srcTy.isa() && dstTy.isInteger(1)) { - Value zero = - rewriter.create(loc, rewriter.getFloatAttr(srcTy, 0.0)); - return rewriter.create(loc, CmpFPredicate::UNE, - args.front(), zero); + Value zero = rewriter.create( + loc, rewriter.getFloatAttr(srcTy, 0.0)); + return rewriter.create(loc, arith::CmpFPredicate::UNE, + args.front(), zero); } - if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto zero = - rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto half = - rewriter.create(loc, rewriter.getF32FloatAttr(0.5f)); + if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { + auto zero = rewriter.create( + loc, rewriter.getF32FloatAttr(0.0f)); + auto half = rewriter.create( + loc, rewriter.getF32FloatAttr(0.5f)); - auto intMin = rewriter.create( + auto intMin = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto intMax = rewriter.create( + auto intMax = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto added = rewriter.create(loc, args[0], half); - auto subbed = rewriter.create(loc, args[0], half); - auto negative = - rewriter.create(loc, CmpFPredicate::OLT, args[0], zero); + auto added = rewriter.create(loc, args[0], half); + auto subbed = rewriter.create(loc, args[0], half); + auto negative = rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], zero); auto rounded = rewriter.create(loc, negative, subbed, added); - auto clamped = clampHelper(loc, rounded, intMin, intMax, - CmpFPredicate::OLT, rewriter); + auto clamped = clampHelper( + loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter); - return rewriter.create(loc, dstTy, clamped); + return rewriter.create(loc, dstTy, clamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. 
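Both boolean-cast paths in this hunk (the float path earlier and the integer path that follows) reduce to a compare against zero rather than a bit truncation. As a scalar reference — illustrative only, not code from this patch — the emitted arith.cmpf/arith.cmpi sequences compute:

```cpp
// Scalar reference for the cast-to-i1 lowerings in this hunk.
// "une" is unordered-or-not-equal, so a NaN input also maps to true.
bool floatToBool(float v) { return !(v == 0.0f); }  // arith.cmpf une, v, 0.0
bool intToBool(int v) { return v != 0; }            // arith.cmpi ne, v, 0
```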
if (srcTy.isa() && dstTy.isInteger(1)) { - Value zero = - rewriter.create(loc, 0, srcTy.getIntOrFloatBitWidth()); - return rewriter.create(loc, CmpIPredicate::ne, args.front(), - zero); + Value zero = rewriter.create( + loc, 0, srcTy.getIntOrFloatBitWidth()); + return rewriter.create(loc, arith::CmpIPredicate::ne, + args.front(), zero); } if (srcTy.isa() && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, - mlir::None); + return rewriter.create(loc, resultTypes, args, + mlir::None); if (srcTy.isa() && dstTy.isa() && !bitExtend) { - auto intMin = rewriter.create( + auto intMin = rewriter.create( loc, APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue(), srcTy.getIntOrFloatBitWidth()); - auto intMax = rewriter.create( + auto intMax = rewriter.create( loc, APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue(), srcTy.getIntOrFloatBitWidth()); - auto clamped = clampHelper(loc, args[0], intMin, intMax, - CmpIPredicate::slt, rewriter); - return rewriter.create(loc, dstTy, clamped); + auto clamped = clampHelper( + loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter); + return rewriter.create(loc, dstTy, clamped); } } @@ -832,50 +830,50 @@ PatternRewriter &rewriter) { Location loc = op->getLoc(); if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OLT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpIPredicate::slt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OGT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OGT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpIPredicate::sgt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return rewriter.create(loc, args); if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return rewriter.create(loc, args); return {}; } @@ -911,7 +909,7 @@ return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); - auto fillValue = rewriter.create(loc, fillValueAttr); + auto fillValue = rewriter.create(loc, fillValueAttr); auto filledTensor = rewriter.create(loc, fillValue, initTensor).result(); @@ -1032,7 +1030,8 @@ weightShape[3], weightShape[0]}; auto weightPermAttr = DenseIntElementsAttr::get( RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm); - Value weightPermValue = rewriter.create(loc, weightPermAttr); + 
Value weightPermValue = + rewriter.create(loc, weightPermAttr); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); weight = rewriter.create(loc, newWeightTy, weight, @@ -1041,7 +1040,7 @@ Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( loc, resultTy.getShape(), resultETy); - Value zero = rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); @@ -1075,8 +1074,8 @@ auto kZp = rewriter.getI32IntegerAttr( quantizationInfo.weight_zp().getValue().getSExtValue()); - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto iZpVal = rewriter.create(loc, iZp); + auto kZpVal = rewriter.create(loc, kZp); Value conv = rewriter .create( @@ -1091,8 +1090,8 @@ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1113,8 +1112,8 @@ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1223,7 +1222,7 @@ Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( loc, linalgConvTy.getShape(), resultETy); - Value zero = rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); @@ -1244,15 +1243,15 @@ getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); rewriter.replaceOp(op, result); } else { - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto iZpVal = rewriter.create(loc, iZp); + auto kZpVal = rewriter.create(loc, kZp); Value conv = rewriter .create( @@ -1268,8 +1267,8 @@ getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1382,7 +1381,7 @@ SmallVector filteredDims = filterDynamicDims(dynDims); auto zeroAttr = rewriter.getZeroAttr(outputElementTy); - Value zero = rewriter.create(loc, zeroAttr); + Value zero = rewriter.create(loc, zeroAttr); auto initTensor = rewriter.create( loc, filteredDims, outputTy.getShape(), outputTy.getElementType()); Value zeroTensor = @@ -1395,10 +1394,10 @@ } auto quantizationInfo = op.quantization_info().getValue(); - auto aZp = rewriter.create( + auto aZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.a_zp().getValue().getSExtValue())); - auto bZp = rewriter.create( + auto bZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.b_zp().getValue().getSExtValue())); rewriter.replaceOpWithNewOp( @@ -1458,14 +1457,15 @@ // When quantized, the input 
elemeny type is not the same as the output Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy); - Value zero = rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); SmallVector permutation{1, 0}; auto permutationAttr = DenseIntElementsAttr::get( RankedTensorType::get({2}, rewriter.getI64Type()), permutation); - Value permutationValue = rewriter.create(loc, permutationAttr); + Value permutationValue = + rewriter.create(loc, permutationAttr); SmallVector newWeightShape{weightShape[1], weightShape[0]}; Type newWeightTy = @@ -1494,8 +1494,8 @@ indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1504,10 +1504,10 @@ } auto quantizationInfo = op.quantization_info().getValue(); - auto inputZp = rewriter.create( + auto inputZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.input_zp().getValue().getSExtValue())); - auto outputZp = rewriter.create( + auto outputZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.weight_zp().getValue().getSExtValue())); Value matmul = @@ -1524,8 +1524,8 @@ indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1738,7 +1738,7 @@ Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { - multiplierConstant = rewriter.create( + multiplierConstant = rewriter.create( loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector multiplierExprs{ @@ -1746,7 +1746,7 @@ auto multiplierType = RankedTensorType::get({static_cast(multiplierValues.size())}, rewriter.getI32Type()); - genericInputs.push_back(rewriter.create( + genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, @@ -1761,7 +1761,7 @@ Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { - shiftConstant = rewriter.create( + shiftConstant = rewriter.create( loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector shiftExprs = { @@ -1769,7 +1769,7 @@ auto shiftType = RankedTensorType::get({static_cast(shiftValues.size())}, rewriter.getIntegerType(8)); - genericInputs.push_back(rewriter.create( + genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(shiftType, shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, @@ -1817,22 +1817,24 @@ valueTy.getIntOrFloatBitWidth()), value) .getResult(0); - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } else { - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } } - value = nestedBuilder.create(nestedLoc, value, inputZp); + value = + nestedBuilder.create(nestedLoc, value, inputZp); value = nestedBuilder.create( loc, nestedBuilder.getI32Type(), value, multiplier, shift, nestedBuilder.getBoolAttr(doubleRound)); // Move to the new 
zero-point. - value = nestedBuilder.create(nestedLoc, value, outputZp); + value = + nestedBuilder.create(nestedLoc, value, outputZp); // Saturate to the output size. IntegerType outIntType = @@ -1848,19 +1850,17 @@ intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); } - auto intMinVal = nestedBuilder.create( - loc, - nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMin)); - auto intMaxVal = nestedBuilder.create( - loc, - nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMax)); + auto intMinVal = nestedBuilder.create( + loc, nestedBuilder.getI32IntegerAttr(intMin)); + auto intMaxVal = nestedBuilder.create( + loc, nestedBuilder.getI32IntegerAttr(intMax)); - value = - clampHelper(nestedLoc, value, intMinVal, intMaxVal, - CmpIPredicate::slt, nestedBuilder); + value = clampHelper( + nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt, + nestedBuilder); if (outIntType.getWidth() < 32) { - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), value); @@ -1923,37 +1923,39 @@ Value x = rewriter.create(loc, 2); Value channel = rewriter.create(loc, 3); - auto hwMin = - rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto hMax = rewriter.create( + auto hwMin = rewriter.create( + loc, rewriter.getI32IntegerAttr(0)); + auto hMax = rewriter.create( loc, rewriter.getI32IntegerAttr(imageH - 1)); - auto wMax = rewriter.create( + auto wMax = rewriter.create( loc, rewriter.getI32IntegerAttr(imageW - 1)); - Value inY = rewriter.create(loc, rewriter.getI32Type(), y); - Value inX = rewriter.create(loc, rewriter.getI32Type(), x); + Value inY = + rewriter.create(loc, rewriter.getI32Type(), y); + Value inX = + rewriter.create(loc, rewriter.getI32Type(), x); int32_t shift = op.shift(); bool floatingPointMode = shift == 0; Value yStride, xStride, yOffset, xOffset; if (floatingPointMode) { - yStride = rewriter.create(loc, op.stride_fp()[0]); - xStride = rewriter.create(loc, op.stride_fp()[1]); - yOffset = rewriter.create(loc, op.offset_fp()[0]); - xOffset = rewriter.create(loc, op.offset_fp()[1]); + yStride = rewriter.create(loc, op.stride_fp()[0]); + xStride = rewriter.create(loc, op.stride_fp()[1]); + yOffset = rewriter.create(loc, op.offset_fp()[0]); + xOffset = rewriter.create(loc, op.offset_fp()[1]); } else { SmallVector stride, offset; getValuesFromIntArrayAttribute(op.stride(), stride); getValuesFromIntArrayAttribute(op.offset(), offset); - yStride = rewriter.create( + yStride = rewriter.create( loc, rewriter.getI32IntegerAttr(stride[0])); - xStride = rewriter.create( + xStride = rewriter.create( loc, rewriter.getI32IntegerAttr(stride[1])); - yOffset = rewriter.create( + yOffset = rewriter.create( loc, rewriter.getI32IntegerAttr(offset[0])); - xOffset = rewriter.create( + xOffset = rewriter.create( loc, rewriter.getI32IntegerAttr(offset[1])); } @@ -1963,85 +1965,89 @@ // dx = x - ix Value ix, iy, dx, dy; if (floatingPointMode) { - Value y = rewriter.create(loc, rewriter.getF32Type(), inY); - Value x = rewriter.create(loc, rewriter.getF32Type(), inX); + Value y = + rewriter.create(loc, rewriter.getF32Type(), inY); + Value x = + rewriter.create(loc, rewriter.getF32Type(), inX); - y = rewriter.create(loc, y, yStride); - x = rewriter.create(loc, x, xStride); + y = rewriter.create(loc, y, yStride); + x = rewriter.create(loc, x, xStride); - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); + y = rewriter.create(loc, y, yOffset); + x = rewriter.create(loc, x, xOffset); - iy = 
rewriter.create(loc, y); - ix = rewriter.create(loc, x); + iy = rewriter.create(loc, y); + ix = rewriter.create(loc, x); - dy = rewriter.create(loc, y, iy); - dx = rewriter.create(loc, x, ix); + dy = rewriter.create(loc, y, iy); + dx = rewriter.create(loc, x, ix); - iy = rewriter.create(loc, rewriter.getI32Type(), iy); - ix = rewriter.create(loc, rewriter.getI32Type(), ix); + iy = rewriter.create(loc, rewriter.getI32Type(), iy); + ix = rewriter.create(loc, rewriter.getI32Type(), ix); } else { - Value shiftVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(shift)); + Value shiftVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(shift)); - Value y = rewriter.create(loc, inY, yStride); - Value x = rewriter.create(loc, inX, xStride); + Value y = rewriter.create(loc, inY, yStride); + Value x = rewriter.create(loc, inX, xStride); - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); + y = rewriter.create(loc, y, yOffset); + x = rewriter.create(loc, x, xOffset); - iy = rewriter.create(loc, y, shiftVal); - ix = rewriter.create(loc, x, shiftVal); + iy = rewriter.create(loc, y, shiftVal); + ix = rewriter.create(loc, x, shiftVal); - Value yTrunc = rewriter.create(loc, iy, shiftVal); - Value xTrunc = rewriter.create(loc, ix, shiftVal); + Value yTrunc = rewriter.create(loc, iy, shiftVal); + Value xTrunc = rewriter.create(loc, ix, shiftVal); - dy = rewriter.create(loc, y, yTrunc); - dx = rewriter.create(loc, x, xTrunc); + dy = rewriter.create(loc, y, yTrunc); + dx = rewriter.create(loc, x, xTrunc); } if (op.mode() == "NEAREST_NEIGHBOR") { Value yPred, xPred; // Round the index position towards the closest pixel location. if (floatingPointMode) { - auto halfVal = - rewriter.create(loc, rewriter.getF32FloatAttr(0.5f)); - yPred = rewriter.create(loc, CmpFPredicate::OGE, dy, - halfVal); - xPred = rewriter.create(loc, CmpFPredicate::OGE, dx, - halfVal); + auto halfVal = rewriter.create( + loc, rewriter.getF32FloatAttr(0.5f)); + yPred = rewriter.create(loc, arith::CmpFPredicate::OGE, + dy, halfVal); + xPred = rewriter.create(loc, arith::CmpFPredicate::OGE, + dx, halfVal); } else { - auto halfVal = rewriter.create( + auto halfVal = rewriter.create( loc, rewriter.getI32IntegerAttr(1 << (shift - 1))); - yPred = rewriter.create(loc, CmpIPredicate::sge, dy, - halfVal); - xPred = rewriter.create(loc, CmpIPredicate::sge, dx, - halfVal); + yPred = rewriter.create(loc, arith::CmpIPredicate::sge, + dy, halfVal); + xPred = rewriter.create(loc, arith::CmpIPredicate::sge, + dx, halfVal); } - auto zeroVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto oneVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); + auto zeroVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(0)); + auto oneVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); auto yOffset = rewriter.create(loc, yPred, oneVal, zeroVal); auto xOffset = rewriter.create(loc, xPred, oneVal, zeroVal); - iy = rewriter.create(loc, iy, yOffset); - ix = rewriter.create(loc, ix, xOffset); + iy = rewriter.create(loc, iy, yOffset); + ix = rewriter.create(loc, ix, xOffset); // Clamp the to be within the bounds of the input image. - iy = clampHelper(loc, iy, hwMin, hMax, CmpIPredicate::slt, - rewriter); - ix = clampHelper(loc, ix, hwMin, wMax, CmpIPredicate::slt, - rewriter); + iy = clampHelper(loc, iy, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + ix = clampHelper(loc, ix, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); // Read the value from the input array. 
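As an aside before the read: a scalar picture of the NEAREST_NEIGHBOR index math handled above. This is illustrative only and mirrors the floating-point path, where the comparison against 0.5 decides whether to step to the next pixel and the clamp keeps the index inside the image.

```cpp
#include <cstdint>

// Scalar reference for the floating-point NEAREST_NEIGHBOR path above:
// i is the truncated coordinate (iy or ix), frac the fractional remainder
// (dy or dx), and dimMax is imageH - 1 or imageW - 1 respectively.
int32_t nearestIndex(int32_t i, float frac, int32_t dimMax) {
  if (frac >= 0.5f)   // arith.cmpf oge + select(1, 0) + arith.addi
    i += 1;
  if (i < 0)          // clampHelper with arith::CmpIPredicate::slt
    i = 0;
  if (i > dimMax)
    i = dimMax;
  return i;
}
```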
- iy = rewriter.create(loc, rewriter.getIndexType(), iy); - ix = rewriter.create(loc, rewriter.getIndexType(), ix); + iy = rewriter.create(loc, rewriter.getIndexType(), + iy); + ix = rewriter.create(loc, rewriter.getIndexType(), + ix); Value result = rewriter.create( loc, input, ValueRange{batch, iy, ix, channel}); @@ -2055,25 +2061,29 @@ Value y0 = iy; Value x0 = ix; - auto oneVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); - Value y1 = rewriter.create(loc, y0, oneVal); - Value x1 = rewriter.create(loc, x0, oneVal); - - y0 = clampHelper(loc, y0, hwMin, hMax, CmpIPredicate::slt, - rewriter); - y1 = clampHelper(loc, y1, hwMin, hMax, CmpIPredicate::slt, - rewriter); - - x0 = clampHelper(loc, x0, hwMin, wMax, CmpIPredicate::slt, - rewriter); - x1 = clampHelper(loc, x1, hwMin, wMax, CmpIPredicate::slt, - rewriter); - - y0 = rewriter.create(loc, rewriter.getIndexType(), y0); - y1 = rewriter.create(loc, rewriter.getIndexType(), y1); - x0 = rewriter.create(loc, rewriter.getIndexType(), x0); - x1 = rewriter.create(loc, rewriter.getIndexType(), x1); + auto oneVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); + Value y1 = rewriter.create(loc, y0, oneVal); + Value x1 = rewriter.create(loc, x0, oneVal); + + y0 = clampHelper(loc, y0, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + y1 = clampHelper(loc, y1, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + + x0 = clampHelper(loc, x0, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); + x1 = clampHelper(loc, x1, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); + + y0 = rewriter.create(loc, rewriter.getIndexType(), + y0); + y1 = rewriter.create(loc, rewriter.getIndexType(), + y1); + x0 = rewriter.create(loc, rewriter.getIndexType(), + x0); + x1 = rewriter.create(loc, rewriter.getIndexType(), + x1); Value y0x0 = rewriter.create( loc, input, ValueRange{batch, y0, x0, channel}); @@ -2085,56 +2095,58 @@ loc, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { - auto oneVal = - rewriter.create(loc, rewriter.getF32FloatAttr(1.f)); + auto oneVal = rewriter.create( + loc, rewriter.getF32FloatAttr(1.f)); Value rightPart = dx; - Value leftPart = rewriter.create(loc, oneVal, dx); + Value leftPart = rewriter.create(loc, oneVal, dx); - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + y0x0 = rewriter.create(loc, y0x0, leftPart); + y0x1 = rewriter.create(loc, y0x1, rightPart); + Value topAcc = rewriter.create(loc, y0x0, y0x1); - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + y1x0 = rewriter.create(loc, y1x0, leftPart); + y1x1 = rewriter.create(loc, y1x1, rightPart); + Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; - Value topPart = rewriter.create(loc, oneVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value topPart = rewriter.create(loc, oneVal, dy); + topAcc = rewriter.create(loc, topAcc, topPart); + bottomAcc = + rewriter.create(loc, bottomAcc, bottomPart); + Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); } else { - y0x0 = rewriter.create(loc, resultElementTy, y0x0); - y0x1 = rewriter.create(loc, resultElementTy, y0x1); - y1x0 = rewriter.create(loc, resultElementTy, y1x0); - y1x1 = 
rewriter.create(loc, resultElementTy, y1x1); + y0x0 = rewriter.create(loc, resultElementTy, y0x0); + y0x1 = rewriter.create(loc, resultElementTy, y0x1); + y1x0 = rewriter.create(loc, resultElementTy, y1x0); + y1x1 = rewriter.create(loc, resultElementTy, y1x1); if (resultElementTy.getIntOrFloatBitWidth() > 32) { - dx = rewriter.create(loc, resultElementTy, dx); - dy = rewriter.create(loc, resultElementTy, dy); + dx = rewriter.create(loc, resultElementTy, dx); + dy = rewriter.create(loc, resultElementTy, dy); } - auto unitVal = rewriter.create( + auto unitVal = rewriter.create( loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift)); Value rightPart = dx; - Value leftPart = rewriter.create(loc, unitVal, dx); + Value leftPart = rewriter.create(loc, unitVal, dx); - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + y0x0 = rewriter.create(loc, y0x0, leftPart); + y0x1 = rewriter.create(loc, y0x1, rightPart); + Value topAcc = rewriter.create(loc, y0x0, y0x1); - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + y1x0 = rewriter.create(loc, y1x0, leftPart); + y1x1 = rewriter.create(loc, y1x1, rightPart); + Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; - Value topPart = rewriter.create(loc, unitVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value topPart = rewriter.create(loc, unitVal, dy); + topAcc = rewriter.create(loc, topAcc, topPart); + bottomAcc = + rewriter.create(loc, bottomAcc, bottomPart); + Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); @@ -2189,12 +2201,12 @@ Location loc = op.getLoc(); int axis = op.axis(); Value axisValue = - rewriter.create(loc, rewriter.getIndexAttr(axis)); + rewriter.create(loc, rewriter.getIndexAttr(axis)); int rank = resultType.getRank(); SmallVector offsets, sizes, strides; sizes.reserve(rank); - strides.resize(rank, rewriter.create(loc, 1)); - offsets.resize(rank, rewriter.create(loc, 0)); + strides.resize(rank, rewriter.create(loc, 1)); + offsets.resize(rank, rewriter.create(loc, 0)); for (int i = 0; i < rank; ++i) { sizes.push_back( @@ -2204,14 +2216,14 @@ Value resultDimSize = sizes[axis]; for (auto arg : adaptor.getOperands().drop_front()) { auto size = rewriter.create(loc, arg, axisValue); - resultDimSize = rewriter.create(loc, resultDimSize, size); + resultDimSize = rewriter.create(loc, resultDimSize, size); } sizes[axis] = resultDimSize; Value init = rewriter.create( loc, resultType.getShape(), resultType.getElementType()); - Value zeroVal = rewriter.create( + Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(resultType.getElementType())); Value result = rewriter.create(loc, zeroVal, init).getResult(0); @@ -2220,7 +2232,8 @@ sizes[axis] = rewriter.create(loc, arg, axisValue); result = rewriter.create(loc, arg, result, offsets, sizes, strides); - offsets[axis] = rewriter.create(loc, offsets[axis], sizes[axis]); + offsets[axis] = + rewriter.create(loc, offsets[axis], sizes[axis]); } rewriter.replaceOp(op, result); return success(); @@ -2266,10 +2279,11 @@ auto index = rewriter.create(nestedLoc, i).getResult(); if (i == axis) { - auto one = rewriter.create(nestedLoc, 1); + auto one = rewriter.create(nestedLoc, 1); auto sizeMinusOne = - 
rewriter.create(nestedLoc, axisDimSize, one); - index = rewriter.create(nestedLoc, sizeMinusOne, index); + rewriter.create(nestedLoc, axisDimSize, one); + index = rewriter.create(nestedLoc, sizeMinusOne, + index); } indices.push_back(index); @@ -2383,9 +2397,10 @@ "tosa.pad to linalg lowering encountered an unknown element type"); } - Value lowIndex = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value lowIndex = + rewriter.create(loc, rewriter.getIndexAttr(0)); Value highIndex = - rewriter.create(loc, rewriter.getIndexAttr(1)); + rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector lowValues; SmallVector highValues; @@ -2394,22 +2409,22 @@ highValues.reserve(rank); for (int i = 0; i < rank; i++) { - Value inputIndex = rewriter.createOrFold(loc, i); + Value inputIndex = rewriter.createOrFold(loc, i); Value lowVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, lowIndex})); Value highVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, highIndex})); - lowVal = rewriter.createOrFold(loc, rewriter.getIndexType(), - lowVal); - highVal = rewriter.createOrFold(loc, rewriter.getIndexType(), - highVal); + lowVal = rewriter.createOrFold( + loc, rewriter.getIndexType(), lowVal); + highVal = rewriter.createOrFold( + loc, rewriter.getIndexType(), highVal); lowValues.push_back(lowVal); highValues.push_back(highVal); } - Value constant = rewriter.create(loc, constantAttr); + Value constant = rewriter.create(loc, constantAttr); auto newPadOp = linalg::PadTensorOp::createPadScalarOp( padOp.getType(), input, constant, lowValues, highValues, @@ -2464,7 +2479,7 @@ .create(loc, ArrayRef({}), resultTy.getShape(), outElementTy) .result(); - auto fillValueIdx = rewriter.create( + auto fillValueIdx = rewriter.create( loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = rewriter.create(loc, fillValueIdx, initTensorIdx) @@ -2483,7 +2498,8 @@ return rewriter.notifyMatchFailure( argmaxOp, "unsupported tosa.argmax element type"); - auto fillValueMax = rewriter.create(loc, fillValueMaxAttr); + auto fillValueMax = + rewriter.create(loc, fillValueMaxAttr); auto filledTensorMax = rewriter.create(loc, fillValueMax, initTensorMax) .result(); @@ -2513,17 +2529,17 @@ auto oldIndex = blockArgs[1]; auto oldValue = blockArgs[2]; - Value newIndex = rewriter.create( + Value newIndex = rewriter.create( nestedLoc, oldIndex.getType(), rewriter.create(loc, axis)); Value predicate; if (inElementTy.isa()) { - predicate = rewriter.create( - nestedLoc, CmpFPredicate::OGT, newValue, oldValue); + predicate = rewriter.create( + nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else if (inElementTy.isa()) { - predicate = rewriter.create( - nestedLoc, CmpIPredicate::sgt, newValue, oldValue); + predicate = rewriter.create( + nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { didEncounterError = true; return; @@ -2587,7 +2603,7 @@ [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; auto index0 = rewriter.create(loc, 0); - Value index1 = rewriter.create( + Value index1 = rewriter.create( loc, rewriter.getIndexType(), indexValue); auto index2 = rewriter.create(loc, 2); Value extract = rewriter.create( @@ -2648,11 +2664,11 @@ rewriter.setInsertionPointToStart(block); if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && resultElementTy.isInteger(8)) { - Value index = rewriter.create(loc, rewriter.getIndexType(), - inputValue); - Value offset = rewriter.create(loc, 128); - index = rewriter.create(loc, rewriter.getIndexType(), 
index, - offset); + Value index = rewriter.create( + loc, rewriter.getIndexType(), inputValue); + Value offset = rewriter.create(loc, 128); + index = rewriter.create(loc, rewriter.getIndexType(), + index, offset); Value extract = rewriter.create(loc, table, ValueRange{index}); rewriter.create(loc, extract); @@ -2661,35 +2677,35 @@ if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && resultElementTy.isInteger(32)) { - Value extend = rewriter.create( + Value extend = rewriter.create( loc, rewriter.getI32Type(), inputValue); - auto offset = - rewriter.create(loc, rewriter.getI32IntegerAttr(32768)); - auto seven = - rewriter.create(loc, rewriter.getI32IntegerAttr(7)); - auto one = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); - auto b1111111 = - rewriter.create(loc, rewriter.getI32IntegerAttr(127)); + auto offset = rewriter.create( + loc, rewriter.getI32IntegerAttr(32768)); + auto seven = rewriter.create( + loc, rewriter.getI32IntegerAttr(7)); + auto one = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); + auto b1111111 = rewriter.create( + loc, rewriter.getI32IntegerAttr(127)); // Compute the index and fractional part from the input value: // value = value + 32768 // index = value >> 7; // fraction = 0x01111111 & value - auto extendAdd = rewriter.create(loc, extend, offset); - Value index = - rewriter.create(loc, extendAdd, seven); - Value fraction = rewriter.create(loc, extendAdd, b1111111); + auto extendAdd = rewriter.create(loc, extend, offset); + Value index = rewriter.create(loc, extendAdd, seven); + Value fraction = + rewriter.create(loc, extendAdd, b1111111); // Extract the base and next values from the table. // base = (int32_t) table[index]; // next = (int32_t) table[index + 1]; - Value indexPlusOne = rewriter.create(loc, index, one); + Value indexPlusOne = rewriter.create(loc, index, one); - index = - rewriter.create(loc, rewriter.getIndexType(), index); - indexPlusOne = rewriter.create( + index = rewriter.create( + loc, rewriter.getIndexType(), index); + indexPlusOne = rewriter.create( loc, rewriter.getIndexType(), indexPlusOne); Value base = @@ -2697,15 +2713,18 @@ Value next = rewriter.create( loc, table, ValueRange{indexPlusOne}); - base = rewriter.create(loc, rewriter.getI32Type(), base); - next = rewriter.create(loc, rewriter.getI32Type(), next); + base = + rewriter.create(loc, rewriter.getI32Type(), base); + next = + rewriter.create(loc, rewriter.getI32Type(), next); // Use the fractional part to interpolate between the input values: // result = (base << 7) + (next - base) * fraction - Value baseScaled = rewriter.create(loc, base, seven); - Value diff = rewriter.create(loc, next, base); - Value diffScaled = rewriter.create(loc, diff, fraction); - Value result = rewriter.create(loc, baseScaled, diffScaled); + Value baseScaled = rewriter.create(loc, base, seven); + Value diff = rewriter.create(loc, next, base); + Value diffScaled = rewriter.create(loc, diff, fraction); + Value result = + rewriter.create(loc, baseScaled, diffScaled); rewriter.create(loc, result); @@ -2758,7 +2777,7 @@ pad.resize(pad.size() + 2, 0); Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = rewriter.create(loc, initialAttr); SmallVector kernel, stride; getValuesFromIntArrayAttribute(op.kernel(), kernel); @@ -2813,7 +2832,7 @@ Attribute initialAttr = rewriter.getZeroAttr(accETy); Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = 
rewriter.create(loc, initialAttr); + Value initialValue = rewriter.create(loc, initialAttr); SmallVector kernel, stride; getValuesFromIntArrayAttribute(op.kernel(), kernel); @@ -2855,18 +2874,18 @@ ArrayRef({affineMap, affineMap}), getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { - auto zero = rewriter.create(loc, 0); - auto one = rewriter.create(loc, 1); - auto iH = rewriter.create( + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto iH = rewriter.create( loc, poolingOpTy.getDimSize(1) - 1); - auto iW = rewriter.create( + auto iW = rewriter.create( loc, poolingOpTy.getDimSize(2) - 1); // Compute the indices from either end. auto y0 = rewriter.create(loc, 1); auto x0 = rewriter.create(loc, 2); - auto y1 = rewriter.create(loc, iH, y0); - auto x1 = rewriter.create(loc, iW, x0); + auto y1 = rewriter.create(loc, iH, y0); + auto x1 = rewriter.create(loc, iW, x0); // Determines what the portion of valid input is covered by the // kernel. @@ -2874,34 +2893,34 @@ if (pad == 0) return v; - auto padVal = rewriter.create(loc, pad); - Value dx = rewriter.create(loc, x, padVal); + auto padVal = rewriter.create(loc, pad); + Value dx = rewriter.create(loc, x, padVal); - Value cmp = rewriter.create(loc, CmpIPredicate::slt, - dx, zero); + Value cmp = rewriter.create( + loc, arith::CmpIPredicate::slt, dx, zero); Value offset = rewriter.create(loc, cmp, dx, zero); - return rewriter.create(loc, v, offset)->getResult(0); + return rewriter.create(loc, v, offset)->getResult(0); }; // Compute the vertical component of coverage. - auto kH0 = rewriter.create(loc, kernel[0]); + auto kH0 = rewriter.create(loc, kernel[0]); auto kH1 = padFn(kH0, y0, pad[2]); auto kH2 = padFn(kH1, y1, pad[3]); - auto kHCmp = - rewriter.create(loc, CmpIPredicate::slt, kH2, one); + auto kHCmp = rewriter.create( + loc, arith::CmpIPredicate::slt, kH2, one); auto kH3 = rewriter.create(loc, kHCmp, one, kH2); // compute the horizontal component of coverage. - auto kW0 = rewriter.create(loc, kernel[1]); + auto kW0 = rewriter.create(loc, kernel[1]); auto kW1 = padFn(kW0, x0, pad[4]); auto kW2 = padFn(kW1, x1, pad[5]); - auto kWCmp = - rewriter.create(loc, CmpIPredicate::slt, kW2, one); + auto kWCmp = rewriter.create( + loc, arith::CmpIPredicate::slt, kW2, one); auto kW3 = rewriter.create(loc, kWCmp, one, kW2); // Compute the total number of elements and normalize. - Value count = rewriter.create(loc, kH3, kW3); - auto countI = rewriter.create( + Value count = rewriter.create(loc, kH3, kW3); + auto countI = rewriter.create( loc, rewriter.getI32Type(), count); // Divide by the number of summed values. For floats this is just @@ -2910,20 +2929,21 @@ Value poolVal = args[0]; if (accETy.isa()) { auto countF = - rewriter.create(loc, inElementTy, countI); - poolVal = - rewriter.create(loc, poolVal, countF)->getResult(0); + rewriter.create(loc, inElementTy, countI); + poolVal = rewriter.create(loc, poolVal, countF) + ->getResult(0); } else { // If we have quantization information we need to apply an offset // for the input zp value. 
if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); - auto inputZp = rewriter.create( + auto inputZp = rewriter.create( loc, quantizationInfo.input_zp()); Value offset = - rewriter.create(loc, accETy, countI, inputZp); - poolVal = rewriter.create(loc, accETy, poolVal, offset); + rewriter.create(loc, accETy, countI, inputZp); + poolVal = + rewriter.create(loc, accETy, poolVal, offset); } // Compute the multiplier and shift values for the quantization @@ -2933,14 +2953,14 @@ int64_t numerator = ((1 << 30) + 1); int64_t shift = 30; - Value numeratorVal = rewriter.create( + Value numeratorVal = rewriter.create( loc, rewriter.getI32IntegerAttr(numerator)); Value multiplierVal = rewriter - .create(loc, rewriter.getI32Type(), + .create(loc, rewriter.getI32Type(), numeratorVal, countI) .getResult(); - Value shiftVal = rewriter.create( + Value shiftVal = rewriter.create( loc, rewriter.getI8IntegerAttr(shift)); auto scaled = @@ -2954,28 +2974,26 @@ // zeropoint. if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); - auto outputZp = rewriter.create( + auto outputZp = rewriter.create( loc, quantizationInfo.output_zp()); - scaled = - rewriter.create(loc, scaled, outputZp).getResult(); + scaled = rewriter.create(loc, scaled, outputZp) + .getResult(); } // Apply Clip. int64_t outBitwidth = resultETy.getIntOrFloatBitWidth(); - auto min = rewriter.create( - loc, rewriter.getIntegerAttr( - accETy, - APInt::getSignedMinValue(outBitwidth).getSExtValue())); - auto max = rewriter.create( - loc, rewriter.getIntegerAttr( - accETy, - APInt::getSignedMaxValue(outBitwidth).getSExtValue())); - auto clamp = clampHelper( - loc, scaled, min, max, CmpIPredicate::slt, rewriter); + auto min = rewriter.create( + loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(), + accETy); + auto max = rewriter.create( + loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(), + accETy); + auto clamp = clampHelper( + loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter); // Convert type. - poolVal = rewriter.create(loc, resultETy, clamp); + poolVal = rewriter.create(loc, resultETy, clamp); } // Cast to output type. 
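The clampHelper used throughout the TosaToLinalg changes above now takes the arith predicates, but its definition is not part of this hunk. A plausible post-migration shape — an assumption, not the verbatim upstream helper — is a pair of compare+select steps:

```cpp
// Sketch only: the compare+select pattern that the arith::CmpIPredicate /
// arith::CmpFPredicate arguments in the calls above feed into. The real
// upstream helper may differ in details.
template <typename CmpOpTy, typename PredTy>
static mlir::Value clampHelperSketch(mlir::Location loc, mlir::Value arg,
                                     mlir::Value min, mlir::Value max,
                                     PredTy pred, mlir::OpBuilder &b) {
  // With pred = slt / OLT: first raise arg to at least min...
  auto belowMin = b.create<CmpOpTy>(loc, pred, arg, min);
  auto clampedLow = b.create<mlir::SelectOp>(loc, belowMin, min, arg);
  // ...then lower the result to at most max.
  auto aboveMax = b.create<CmpOpTy>(loc, pred, max, clampedLow);
  return b.create<mlir::SelectOp>(loc, aboveMax, max, clampedLow);
}
```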
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -12,6 +12,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" @@ -33,9 +34,9 @@ : public TosaToLinalgOnTensorsBase { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnFunction() override { diff --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRStandard MLIRPass diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToStandard/TosaToStandard.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -28,7 +29,7 @@ LogicalResult matchAndRewrite(tosa::ConstOp op, PatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp<::ConstantOp>(op, op.value()); + rewriter.replaceOpWithNewOp(op, op.value()); return success(); } }; @@ -67,12 +68,12 @@ bool doubleRound = op.double_round(); Type inType = op.value().getType(); - Value one8 = rewriter.create( + Value one8 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1)); - Value one64 = rewriter.create( + Value one64 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); - Value shiftSubOne8 = rewriter.create(loc, shift8, one8); + Value shiftSubOne8 = rewriter.create(loc, shift8, one8); // The rounding value semantics below equate to the following code: // int64_t round = 1 << (shift - 1); @@ -83,45 +84,45 @@ // // Note that minimal bitwidth operators are used throughout the block. 
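Spelled out as a scalar reference — reconstructed from the IR that follows, so an inference rather than a quotation — the rounding value built below is:

```cpp
#include <cstdint>

// Scalar reference for the rounding term built below. `value` is the i32
// input of tosa.apply_scale and `shift` its i8 shift amount.
int64_t applyScaleRound(int32_t value, int8_t shift, bool doubleRound) {
  int64_t round = int64_t(1) << (shift - 1);
  // Double rounding nudges the bias by 1 << 30 toward the sign of the
  // input, and only takes effect when shift >= 32.
  if (doubleRound && shift >= 32)
    round += value >= 0 ? (int64_t(1) << 30) : -(int64_t(1) << 30);
  return round;
}
// The scaled result is then (int64_t(value) * multiplier + round) >> shift,
// with multiplier the i32 scale factor, truncated back to i32 at the end.
```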
- Value round64 = rewriter.create( + Value round64 = rewriter.create( loc, one64, - rewriter.create(loc, rewriter.getI64Type(), - shiftSubOne8)); + rewriter.create(loc, rewriter.getI64Type(), + shiftSubOne8)); // Double rounding performs a round operation before the shift if (doubleRound) { - Value one32 = rewriter.create( + Value one32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); - Value shift32 = rewriter.create( - loc, rewriter.getI32Type(), shift8); - Value thirty32 = rewriter.create( + Value shift32 = + rewriter.create(loc, rewriter.getI32Type(), shift8); + Value thirty32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30)); Value shiftThirty32 = - rewriter.create(loc, one32, thirty32); - Value shiftThirty64 = rewriter.create( + rewriter.create(loc, one32, thirty32); + Value shiftThirty64 = rewriter.create( loc, rewriter.getI64Type(), shiftThirty32); // Round value needs to be added or subtracted depending on the sign // of the input value. Value roundAdd64 = - rewriter.create(loc, round64, shiftThirty64); + rewriter.create(loc, round64, shiftThirty64); Value roundSub64 = - rewriter.create(loc, round64, shiftThirty64); + rewriter.create(loc, round64, shiftThirty64); Value zero32 = - rewriter.create(loc, rewriter.getZeroAttr(inType)); - Value valueGreaterThanZero = rewriter.create( - loc, CmpIPredicate::sge, value32, zero32); + rewriter.create(loc, rewriter.getZeroAttr(inType)); + Value valueGreaterThanZero = rewriter.create( + loc, arith::CmpIPredicate::sge, value32, zero32); Value doubleRound64 = rewriter.create( loc, valueGreaterThanZero, roundAdd64, roundSub64); // We only perform double rounding if the shift value is greater than 32. - Value thirtyTwo32 = rewriter.create( + Value thirtyTwo32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32)); - Value shiftGreaterThanThirtyTwo = rewriter.create( - loc, CmpIPredicate::sge, shift32, thirtyTwo32); + Value shiftGreaterThanThirtyTwo = rewriter.create( + loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); round64 = rewriter.create(loc, shiftGreaterThanThirtyTwo, doubleRound64, round64); } @@ -133,20 +134,19 @@ // Note that multiply and shift need to be performed in i64 to preserve bits. Value value64 = - rewriter.create(loc, rewriter.getI64Type(), value32); - Value multiplier64 = rewriter.create( + rewriter.create(loc, rewriter.getI64Type(), value32); + Value multiplier64 = rewriter.create( loc, rewriter.getI64Type(), multiplier32); Value shift64 = - rewriter.create(loc, rewriter.getI64Type(), shift8); + rewriter.create(loc, rewriter.getI64Type(), shift8); // Multiply as a pair of i64 values to guarantee the end value fits. 
- Value result64 = rewriter.create(loc, value64, multiplier64); - result64 = rewriter.create(loc, result64, round64); - result64 = - rewriter.create(loc, result64, shift64); + Value result64 = rewriter.create(loc, value64, multiplier64); + result64 = rewriter.create(loc, result64, round64); + result64 = rewriter.create(loc, result64, shift64); - Value result32 = rewriter.create( - loc, rewriter.getI32Type(), result64); + Value result32 = + rewriter.create(loc, rewriter.getI32Type(), result64); rewriter.replaceOp(op, result32); return success(); diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp @@ -12,6 +12,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToStandard/TosaToStandard.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -34,6 +35,7 @@ target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -8,6 +8,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRGPUOps MLIRLLVMIR MLIRMemRef diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -16,6 +16,7 @@ #include "../PassDetail.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -116,7 +117,7 @@ /// Return true if the constant is a splat to a 2D vector so that it can be /// converted to a MMA constant matrix op. -static bool constantSupportsMMAMatrixType(ConstantOp constantOp) { +static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { auto vecType = constantOp.getType().dyn_cast(); if (!vecType || vecType.getRank() != 2) return false; @@ -138,7 +139,7 @@ return transferWriteSupportsMMAMatrixType(transferWrite); if (auto contract = dyn_cast(op)) return contractSupportsMMAMatrixType(contract); - if (auto constant = dyn_cast(op)) + if (auto constant = dyn_cast(op)) return constantSupportsMMAMatrixType(constant); if (auto broadcast = dyn_cast(op)) return broadcastSupportsMMAMatrixType(broadcast); @@ -324,13 +325,13 @@ } /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 
-static void convertConstantOp(ConstantOp op, +static void convertConstantOp(arith::ConstantOp op, llvm::DenseMap &valueMapping) { assert(constantSupportsMMAMatrixType(op)); OpBuilder b(op); - Attribute splat = op.getValue().cast().getSplatValue(); + Attribute splat = op.value().cast().getSplatValue(); auto scalarConstant = - b.create(op.getLoc(), splat.getType(), splat); + b.create(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); auto vecType = op.getType().cast(); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( @@ -439,7 +440,7 @@ convertTransferWriteOp(transferWrite, valueMapping); } else if (auto contractOp = dyn_cast(op)) { convertContractOp(contractOp, valueMapping); - } else if (auto constantOp = dyn_cast(op)) { + } else if (auto constantOp = dyn_cast(op)) { convertConstantOp(constantOp, valueMapping); } else if (auto broadcastOp = dyn_cast(op)) { convertBroadcastOp(broadcastOp, valueMapping); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -13,6 +13,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRArmNeon MLIRArmSVE MLIRArmSVETransforms diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -59,7 +60,7 @@ return rewriter.create(loc, from, into, offset); return rewriter.create( loc, vectorType, from, into, - rewriter.create(loc, offset)); + rewriter.create(loc, offset)); } // Helper that picks the proper sequence for extracting. @@ -86,7 +87,7 @@ return rewriter.create(loc, vector, offset); return rewriter.create( loc, vectorType.getElementType(), vector, - rewriter.create(loc, offset)); + rewriter.create(loc, offset)); } // Helper that returns a subset of `arrayAttr` as a vector of int64_t. @@ -797,8 +798,8 @@ auto loc = op.getLoc(); auto elemType = vType.getElementType(); - Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); + Value zero = rewriter.create( + loc, elemType, rewriter.getZeroAttr(elemType)); Value desc = rewriter.create(loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { Value extrLHS = rewriter.create(loc, op.lhs(), i); @@ -1146,11 +1147,11 @@ if (rank == 0) { switch (conversion) { case PrintConversion::ZeroExt64: - value = rewriter.create( + value = rewriter.create( loc, value, IntegerType::get(rewriter.getContext(), 64)); break; case PrintConversion::SignExt64: - value = rewriter.create( + value = rewriter.create( loc, value, IntegerType::get(rewriter.getContext(), 64)); break; case PrintConversion::None: @@ -1233,8 +1234,8 @@ } // Extract/insert on a lower ranked extract strided slice op. 
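A scalar picture of the strided-slice copy that follows, illustrative only: the result is seeded with a zero splat (now built from arith.constant), and `size` elements are gathered `stride` apart starting at `offset`.

```cpp
#include <cstdint>
#include <vector>

// Scalar picture of the extract_strided_slice copy loop below: gather
// `size` elements from `src`, starting at `offset` and stepping by
// `stride`, into a zero-initialized destination (the zero splat in the IR).
std::vector<float> extractStridedSlice(const std::vector<float> &src,
                                       int64_t offset, int64_t size,
                                       int64_t stride) {
  std::vector<float> dst(size, 0.0f);
  for (int64_t off = offset, idx = 0; idx < size; off += stride, ++idx)
    dst[idx] = src[off];
  return dst;
}
```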
- Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); + Value zero = rewriter.create( + loc, elemType, rewriter.getZeroAttr(elemType)); Value res = rewriter.create(loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/AMX/Transforms.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" @@ -42,6 +43,7 @@ // Override explicitly to allow conditional dialect dependence. void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + registry.insert(); registry.insert(); if (enableArmNeon) registry.insert(); @@ -84,6 +86,7 @@ // Architecture specific augmentations. LLVMConversionTarget target(getContext()); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); diff --git a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt @@ -8,6 +8,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRLLVMIR MLIRMemRef MLIRTransforms diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -17,6 +17,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -123,8 +124,8 @@ return Value(); Location loc = xferOp.getLoc(); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); return b.create(loc, xferOp.mask(), ivI32); } @@ -171,13 +172,14 @@ bindDims(xferOp.getContext(), d0, d1); Value base = xferOp.indices()[dim.getValue()]; Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); - cond = lb.create(CmpIPredicate::sgt, memrefDim, memrefIdx); + cond = lb.create(arith::CmpIPredicate::sgt, memrefDim, + memrefIdx); } // Condition check 2: Masked in? if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { if (cond) - cond = lb.create(cond, maskCond); + cond = lb.create(cond, maskCond); else cond = maskCond; } @@ -704,10 +706,10 @@ } // Loop bounds and step. - auto lb = locB.create(0); - auto ub = locB.create( + auto lb = locB.create(0); + auto ub = locB.create( castedDataType.getDimSize(castedDataType.getRank() - 1)); - auto step = locB.create(1); + auto step = locB.create(1); // TransferWriteOps that operate on tensors return the modified tensor and // require a loop state. auto loopState = Strategy::initialLoopState(xferOp); @@ -897,7 +899,7 @@ // Generate fully unrolled loop of transfer ops. 
Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = rewriter.create(loc, i); vec = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), @@ -1023,7 +1025,7 @@ // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = rewriter.create(loc, i); auto updatedSource = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), @@ -1114,8 +1116,8 @@ ValueRange loopState) { SmallVector indices; auto dim = get1dMemrefIndices(b, xferOp, iv, indices); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); auto vec = loopState[0]; // In case of out-of-bounds access, leave `vec` as is (was initialized with @@ -1147,8 +1149,8 @@ ValueRange /*loopState*/) { SmallVector indices; auto dim = get1dMemrefIndices(b, xferOp, iv, indices); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); // Nothing to do in case of out-of-bounds access. generateInBoundsCheck( @@ -1224,9 +1226,10 @@ // Loop bounds, step, state... Location loc = xferOp.getLoc(); auto vecType = xferOp.getVectorType(); - auto lb = rewriter.create(loc, 0); - auto ub = rewriter.create(loc, vecType.getDimSize(0)); - auto step = rewriter.create(loc, 1); + auto lb = rewriter.create(loc, 0); + auto ub = + rewriter.create(loc, vecType.getDimSize(0)); + auto step = rewriter.create(loc, 1); auto loopState = Strategy1d::initialLoopState(rewriter, xferOp); // Generate for loop. diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -221,7 +222,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + return builder.create(loc, type, value); } /// A utility function to check if a value is defined at the top level of an @@ -1887,12 +1888,11 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { - auto lbConst = lb.getDefiningOp(); - auto ubConst = ub.getDefiningOp(); + auto lbConst = lb.getDefiningOp(); + auto ubConst = ub.getDefiningOp(); if (lbConst && ubConst) - return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(), - ubConst.getValue(), step, - bodyBuilderFn); + return buildAffineLoopFromConstants(builder, loc, lbConst.value(), + ubConst.value(), step, bodyBuilderFn); return builder.create(loc, lb, builder.getDimIdentityMap(), ub, builder.getDimIdentityMap(), step, /*iterArgs=*/llvm::None, bodyBuilderFn); diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt @@ -11,6 +11,7 @@ MLIRAffineOpsIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRLoopLikeInterface MLIRMemRef diff --git 
a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -23,6 +23,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -199,7 +200,7 @@ void AffineDataCopyGeneration::runOnFunction() { FuncOp f = getFunction(); OpBuilder topBuilder(f.getBody()); - zeroIndex = topBuilder.create(f.getLoc(), 0); + zeroIndex = topBuilder.create(f.getLoc(), 0); // Nests that are copy-in's or copy-out's; the root AffineForOps of those // nests are stored herein. diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -18,6 +18,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -81,7 +82,7 @@ } else if (isa(op)) { // TODO: Support DMA ops. return false; - } else if (!isa(op)) { + } else if (!isa(op)) { // Register op in the set of ops that have users. opsWithUsers.insert(&op); if (isa(op)) { diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -21,6 +21,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRAffineUtils + MLIRArithmetic MLIRIR MLIRMemRef MLIRPass diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -344,8 +345,8 @@ /// %A = alloc (%M, %N) : memref /// %B = alloc (%M, %N) : memref /// %C = alloc (%M, %N) : memref -/// %f1 = constant 1.0 : f32 -/// %f2 = constant 2.0 : f32 +/// %f1 = arith.constant 1.0 : f32 +/// %f2 = arith.constant 2.0 : f32 /// affine.for %i0 = 0 to %M { /// affine.for %i1 = 0 to %N { /// // non-scoped %f1 @@ -362,18 +363,18 @@ /// affine.for %i5 = 0 to %N { /// %a5 = affine.load %A[%i4, %i5] : memref /// %b5 = affine.load %B[%i4, %i5] : memref -/// %s5 = addf %a5, %b5 : f32 +/// %s5 = arith.addf %a5, %b5 : f32 /// // non-scoped %f1 -/// %s6 = addf %s5, %f1 : f32 +/// %s6 = arith.addf %s5, %f1 : f32 /// // non-scoped %f2 -/// %s7 = addf %s5, %f2 : f32 +/// %s7 = arith.addf %s5, %f2 : f32 /// // diamond dependency. 
-/// %s8 = addf %s7, %s6 : f32 +/// %s8 = arith.addf %s7, %s6 : f32 /// affine.store %s8, %C[%i4, %i5] : memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.constant 7 : index +/// %c42 = arith.constant 42 : index /// %res = load %C[%c7, %c42] : memref /// return %res : f32 /// } @@ -390,11 +391,11 @@ /// %0 = alloc(%arg0, %arg1) : memref /// %1 = alloc(%arg0, %arg1) : memref /// %2 = alloc(%arg0, %arg1) : memref -/// %cst = constant 1.0 : f32 -/// %cst_0 = constant 2.0 : f32 +/// %cst = arith.constant 1.0 : f32 +/// %cst_0 = arith.constant 2.0 : f32 /// affine.for %i0 = 0 to %arg0 { /// affine.for %i1 = 0 to %arg1 step 256 { -/// %cst_1 = constant dense, 1.0> : +/// %cst_1 = arith.constant dense, 1.0> : /// vector<256xf32> /// vector.transfer_write %cst_1, %0[%i0, %i1] : /// vector<256xf32>, memref @@ -402,7 +403,7 @@ /// } /// affine.for %i2 = 0 to %arg0 { /// affine.for %i3 = 0 to %arg1 step 256 { -/// %cst_2 = constant dense, 2.0> : +/// %cst_2 = arith.constant dense, 2.0> : /// vector<256xf32> /// vector.transfer_write %cst_2, %1[%i2, %i3] : /// vector<256xf32>, memref @@ -414,20 +415,20 @@ /// memref, vector<256xf32> /// %4 = vector.transfer_read %1[%i4, %i5] : /// memref, vector<256xf32> -/// %5 = addf %3, %4 : vector<256xf32> -/// %cst_3 = constant dense, 1.0> : +/// %5 = arith.addf %3, %4 : vector<256xf32> +/// %cst_3 = arith.constant dense, 1.0> : /// vector<256xf32> -/// %6 = addf %5, %cst_3 : vector<256xf32> -/// %cst_4 = constant dense, 2.0> : +/// %6 = arith.addf %5, %cst_3 : vector<256xf32> +/// %cst_4 = arith.constant dense, 2.0> : /// vector<256xf32> -/// %7 = addf %5, %cst_4 : vector<256xf32> -/// %8 = addf %7, %6 : vector<256xf32> +/// %7 = arith.addf %5, %cst_4 : vector<256xf32> +/// %8 = arith.addf %7, %6 : vector<256xf32> /// vector.transfer_write %8, %2[%i4, %i5] : /// vector<256xf32>, memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.constant 7 : index +/// %c42 = arith.constant 42 : index /// %9 = load %2[%c7, %c42] : memref /// return %9 : f32 /// } @@ -444,11 +445,11 @@ /// %0 = alloc(%arg0, %arg1) : memref /// %1 = alloc(%arg0, %arg1) : memref /// %2 = alloc(%arg0, %arg1) : memref -/// %cst = constant 1.0 : f32 -/// %cst_0 = constant 2.0 : f32 +/// %cst = arith.constant 1.0 : f32 +/// %cst_0 = arith.constant 2.0 : f32 /// affine.for %i0 = 0 to %arg0 step 32 { /// affine.for %i1 = 0 to %arg1 step 256 { -/// %cst_1 = constant dense, 1.0> : +/// %cst_1 = arith.constant dense, 1.0> : /// vector<32x256xf32> /// vector.transfer_write %cst_1, %0[%i0, %i1] : /// vector<32x256xf32>, memref @@ -456,7 +457,7 @@ /// } /// affine.for %i2 = 0 to %arg0 step 32 { /// affine.for %i3 = 0 to %arg1 step 256 { -/// %cst_2 = constant dense, 2.0> : +/// %cst_2 = arith.constant dense, 2.0> : /// vector<32x256xf32> /// vector.transfer_write %cst_2, %1[%i2, %i3] : /// vector<32x256xf32>, memref @@ -468,20 +469,20 @@ /// memref vector<32x256xf32> /// %4 = vector.transfer_read %1[%i4, %i5] : /// memref, vector<32x256xf32> -/// %5 = addf %3, %4 : vector<32x256xf32> -/// %cst_3 = constant dense, 1.0> : +/// %5 = arith.addf %3, %4 : vector<32x256xf32> +/// %cst_3 = arith.constant dense, 1.0> : /// vector<32x256xf32> -/// %6 = addf %5, %cst_3 : vector<32x256xf32> -/// %cst_4 = constant dense, 2.0> : +/// %6 = arith.addf %5, %cst_3 : vector<32x256xf32> +/// %cst_4 = arith.constant dense, 2.0> : /// vector<32x256xf32> -/// %7 = addf %5, %cst_4 : vector<32x256xf32> -/// %8 = addf %7, %6 : 
vector<32x256xf32> +/// %7 = arith.addf %5, %cst_4 : vector<32x256xf32> +/// %8 = arith.addf %7, %6 : vector<32x256xf32> /// vector.transfer_write %8, %2[%i4, %i5] : /// vector<32x256xf32>, memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.constant 7 : index +/// %c42 = arith.constant 42 : index /// %9 = load %2[%c7, %c42] : memref /// return %9 : f32 /// } @@ -511,11 +512,11 @@ /// Consider the following example: /// ```mlir /// func @vecred(%in: memref<512xf32>) -> f32 { -/// %cst = constant 0.000000e+00 : f32 +/// %cst = arith.constant 0.000000e+00 : f32 /// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) { /// %ld = affine.load %in[%i] : memref<512xf32> /// %cos = math.cos %ld : f32 -/// %add = addf %part_sum, %cos : f32 +/// %add = arith.addf %part_sum, %cos : f32 /// affine.yield %add : f32 /// } /// return %sum : f32 @@ -531,18 +532,18 @@ /// ```mlir /// #map = affine_map<(d0) -> (-d0 + 500)> /// func @vecred(%arg0: memref<512xf32>) -> f32 { -/// %cst = constant 0.000000e+00 : f32 -/// %cst_0 = constant dense<0.000000e+00> : vector<128xf32> +/// %cst = arith.constant 0.000000e+00 : f32 +/// %cst_0 = arith.constant dense<0.000000e+00> : vector<128xf32> /// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0) /// -> (vector<128xf32>) { /// // %2 is the number of iterations left in the original loop. /// %2 = affine.apply #map(%arg1) /// %3 = vector.create_mask %2 : vector<128xi1> -/// %cst_1 = constant 0.000000e+00 : f32 +/// %cst_1 = arith.constant 0.000000e+00 : f32 /// %4 = vector.transfer_read %arg0[%arg1], %cst_1 : /// memref<512xf32>, vector<128xf32> /// %5 = math.cos %4 : vector<128xf32> -/// %6 = addf %arg2, %5 : vector<128xf32> +/// %6 = arith.addf %arg2, %5 : vector<128xf32> /// // We filter out the effect of last 12 elements using the mask. /// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32> /// affine.yield %7 : vector<128xf32> @@ -674,8 +675,8 @@ /// the vectorized operations. /// /// Example: - /// * 'replaced': %0 = addf %1, %2 : f32 - /// * 'replacement': %0 = addf %1, %2 : vector<128xf32> + /// * 'replaced': %0 = arith.addf %1, %2 : f32 + /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> void registerOpVectorReplacement(Operation *replaced, Operation *replacement); /// Registers the vector replacement of a scalar value. The replacement @@ -772,8 +773,8 @@ /// the vectorized operations. /// /// Example: -/// * 'replaced': %0 = addf %1, %2 : f32 -/// * 'replacement': %0 = addf %1, %2 : vector<128xf32> +/// * 'replaced': %0 = arith.addf %1, %2 : f32 +/// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> void VectorizationState::registerOpVectorReplacement(Operation *replaced, Operation *replacement) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n"); @@ -941,14 +942,14 @@ /// Tries to transform a scalar constant into a vector constant. Returns the /// vector constant if the scalar type is valid vector element type. Returns /// nullptr, otherwise. 
-static ConstantOp vectorizeConstant(ConstantOp constOp, - VectorizationState &state) { +static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp, + VectorizationState &state) { Type scalarTy = constOp.getType(); if (!VectorType::isValidElementType(scalarTy)) return nullptr; auto vecTy = getVectorType(scalarTy, state.strategy); - auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue()); + auto vecAttr = DenseElementsAttr::get(vecTy, constOp.value()); OpBuilder::InsertionGuard guard(state.builder); Operation *parentOp = state.builder.getInsertionBlock()->getParentOp(); @@ -959,7 +960,8 @@ isa(parentOp) && "Expected a vectorized for op"); auto vecForOp = cast(parentOp); state.builder.setInsertionPointToStart(vecForOp.getBody()); - auto newConstOp = state.builder.create(constOp.getLoc(), vecAttr); + auto newConstOp = + state.builder.create(constOp.getLoc(), vecAttr); // Register vector replacement for future uses in the scope. state.registerOpVectorReplacement(constOp, newConstOp); @@ -969,9 +971,9 @@ /// Creates a constant vector filled with the neutral elements of the given /// reduction. The scalar type of vector elements will be taken from /// `oldOperand`. -static ConstantOp createInitialVector(AtomicRMWKind reductionKind, - Value oldOperand, - VectorizationState &state) { +static arith::ConstantOp createInitialVector(AtomicRMWKind reductionKind, + Value oldOperand, + VectorizationState &state) { Type scalarTy = oldOperand.getType(); if (!VectorType::isValidElementType(scalarTy)) return nullptr; @@ -981,7 +983,7 @@ auto vecTy = getVectorType(scalarTy, state.strategy); auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); auto newConstOp = - state.builder.create(oldOperand.getLoc(), vecAttr); + state.builder.create(oldOperand.getLoc(), vecAttr); return newConstOp; } @@ -1128,8 +1130,8 @@ "Vector op not found in replacement map"); // Vectorize constant. - if (auto constOp = operand.getDefiningOp()) { - ConstantOp vecConstant = vectorizeConstant(constOp, state); + if (auto constOp = operand.getDefiningOp()) { + auto vecConstant = vectorizeConstant(constOp, state); LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant); return vecConstant.getResult(); } @@ -1250,7 +1252,7 @@ return false; Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy, state.builder, value.getLoc()); - if (auto constOp = dyn_cast_or_null(value.getDefiningOp())) + if (auto constOp = dyn_cast_or_null(value.getDefiningOp())) return constOp.value() == valueAttr; return false; } @@ -1425,7 +1427,7 @@ // being added to the accumulator by inserting `select` operations, for // example: // - // %res = addf %acc, %val : vector<128xf32> + // %res = arith.addf %acc, %val : vector<128xf32> // %res_masked = select %mask, %res, %acc : vector<128xi1>, vector<128xf32> // affine.yield %res_masked : vector<128xf32> // @@ -1472,7 +1474,7 @@ return vectorizeAffineForOp(forOp, state); if (auto yieldOp = dyn_cast(op)) return vectorizeAffineYieldOp(yieldOp, state); - if (auto constant = dyn_cast(op)) + if (auto constant = dyn_cast(op)) return vectorizeConstant(constant, state); // Other ops with regions are not supported. 
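The SuperVectorize hunks above retarget constant vectorization from the standard dialect's `ConstantOp` to `arith::ConstantOp`: `vectorizeConstant()` splats the scalar attribute into a `DenseElementsAttr` and materializes it as a vector `arith.constant` at the top of the vectorized loop body. As a minimal IR illustration (an editorial sketch, not part of the patch; the 128-lane width is just the vectorization strategy assumed in the surrounding doc-comment examples):

```mlir
// Scalar constant used inside the loop being vectorized.
%cst = arith.constant 0.000000e+00 : f32

// After vectorization, the same value reappears as a splat vector
// constant with the strategy's vector width.
%cst_vec = arith.constant dense<0.000000e+00> : vector<128xf32>
```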
diff --git a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/IR/Builders.h" #include "mlir/Transforms/InliningUtils.h" using namespace mlir; @@ -28,10 +29,18 @@ }; } // end anonymous namespace -void mlir::arith::ArithmeticDialect::initialize() { +void arith::ArithmeticDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" >(); addInterfaces(); } + +/// Materialize an integer or floating point constant. +Operation *arith::ArithmeticDialect::materializeConstant(OpBuilder &builder, + Attribute value, + Type type, + Location loc) { + return builder.create(loc, value, type); +} diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -75,6 +75,112 @@ #include "ArithmeticCanonicalization.inc" } // end anonymous namespace +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +void arith::ConstantOp::getAsmResultNames( + function_ref setNameFn) { + auto type = getType(); + if (auto intCst = value().dyn_cast()) { + auto intType = type.dyn_cast(); + + // Sugar i1 constants with 'true' and 'false'. + if (intType && intType.getWidth() == 1) + return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); + + // Otherwise, build a compex name with the value and type. + SmallString<32> specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << 'c' << intCst.getInt(); + if (intType) + specialName << '_' << type; + setNameFn(getResult(), specialName.str()); + } else { + setNameFn(getResult(), "cst"); + } +} + +/// TODO: disallow arith.constant to return anything other than signless integer +/// or float like. +static LogicalResult verify(arith::ConstantOp op) { + auto type = op.getType(); + // The value's type must match the return type. + if (op.value().getType() != type) { + return op.emitOpError() << "value type " << op.value().getType() + << " must match return type: " << type; + } + // Integer values must be signless. + if (type.isa() && !type.cast().isSignless()) + return op.emitOpError("integer return type must be signless"); + // Any float or elements attribute are acceptable. + if (!op.value().isa()) { + return op.emitOpError( + "value must be an integer, float, or elements attribute"); + } + return success(); +} + +bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { + // The value's type must be the same as the provided type. + if (value.getType() != type) + return false; + // Integer values must be signless. + if (type.isa() && !type.cast().isSignless()) + return false; + // Integer, float, and element attributes are buildable. 
+ return value.isa(); +} + +OpFoldResult arith::ConstantOp::fold(ArrayRef operands) { + return value(); +} + +void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, + int64_t value, unsigned width) { + auto type = builder.getIntegerType(width); + arith::ConstantOp::build(builder, result, type, + builder.getIntegerAttr(type, value)); +} + +void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, + int64_t value, Type type) { + assert(type.isSignlessInteger() && + "ConstantIntOp can only have signless integer type values"); + arith::ConstantOp::build(builder, result, type, + builder.getIntegerAttr(type, value)); +} + +bool arith::ConstantIntOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isSignlessInteger(); + return false; +} + +void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, + const APFloat &value, FloatType type) { + arith::ConstantOp::build(builder, result, type, + builder.getFloatAttr(type, value)); +} + +bool arith::ConstantFloatOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isa(); + return false; +} + +void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, + int64_t value) { + arith::ConstantOp::build(builder, result, builder.getIndexType(), + builder.getIndexAttr(value)); +} + +bool arith::ConstantIndexOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isIndex(); + return false; +} + //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// @@ -377,6 +483,10 @@ /// or(x, x) -> x if (lhs() == rhs()) return rhs(); + /// or(x, ) -> + if (auto rhsAttr = operands[1].dyn_cast_or_null()) + if (rhsAttr.getValue().isAllOnes()) + return rhsAttr; return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a | b; }); @@ -439,6 +549,49 @@ operands, [](APFloat a, APFloat b) { return a / b; }); } +//===----------------------------------------------------------------------===// +// Utility functions for verifying cast ops +//===----------------------------------------------------------------------===// + +template +using type_list = std::tuple *; + +/// Returns a non-null type only if the provided type is one of the allowed +/// types or one of the allowed shaped types of the allowed types. Returns the +/// element type if a valid shaped type is provided. +template +static Type getUnderlyingType(Type type, type_list, + type_list) { + if (type.isa() && !type.isa()) + return {}; + + auto underlyingType = getElementTypeOrSelf(type); + if (!underlyingType.isa()) + return {}; + + return underlyingType; +} + +/// Get allowed underlying types for vectors and tensors. +template +static Type getTypeIfLike(Type type) { + return getUnderlyingType(type, type_list(), + type_list()); +} + +/// Get allowed underlying types for vectors, tensors, and memrefs. 
+template <typename... ElementTypes>
+static Type getTypeIfLikeOrMemRef(Type type) {
+  return getUnderlyingType(type,
+                           type_list<VectorType, TensorType, MemRefType>(),
+                           type_list<ElementTypes...>());
+}
+
+static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
+  return inputs.size() == 1 && outputs.size() == 1 &&
+         succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
+}
+
 //===----------------------------------------------------------------------===//
 // Verifiers for integer and floating point extension/truncation ops
 //===----------------------------------------------------------------------===//
@@ -469,6 +622,21 @@
   return success();
 }
+/// Validate a cast that changes the width of a type.
+template