diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1662,7 +1662,7 @@ // Check that the body defines as single block argument for the induction // variable. auto *body = getBody(); - if (!body->getArgument(1).getType().isInteger(1)) + if (!body->getArgument(1).getType().isBool()) return emitOpError( "expected body second argument to be an index argument for " "the induction variable"); diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -154,6 +154,8 @@ /// Return true if this is an integer (of any signedness), index, or float /// type. bool isIntOrIndexOrFloat() const; + /// Return true if this is a bool type, i.e. a 1-bit integer (`isInteger(1)`). + bool isBool() const; /// Print the current type. void print(raw_ostream &os) const; diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -268,10 +268,10 @@ /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { - if (type.isInteger(1)) + if (type.isBool()) return true; if (auto vecType = type.dyn_cast()) - return vecType.getElementType().isInteger(1); + return vecType.getElementType().isBool(); return false; } @@ -332,7 +332,7 @@ return failure(); elements.push_back(dstAttr); } - } else if (srcElemType.isInteger(1)) { + } else if (srcElemType.isBool()) { return failure(); } else { for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { @@ -403,7 +403,7 @@ } // Bool type. - if (srcType.isInteger(1)) { + if (srcType.isBool()) { // arith.constant can use 0/1 instead of true/false for i1 values. We need // to handle that here. 
auto dstAttr = convertBoolAttr(cstAttr, rewriter); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -134,7 +134,7 @@ /// Casts the given `srcInt` into a boolean value. static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { - if (srcInt.getType().isInteger(1)) + if (srcInt.getType().isBool()) return srcInt; auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder); @@ -144,8 +144,8 @@ /// Casts the given `srcBool` into an integer of `dstType`. static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder) { - assert(srcBool.getType().isInteger(1)); - if (dstType.isInteger(1)) + assert(srcBool.getType().isBool()); + if (dstType.isBool()) return srcBool; Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); Value one = spirv::ConstantOp::getOne(dstType, loc, builder); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -263,22 +263,22 @@ } // tosa::LogicalAnd - if (isa(op) && elementTy.isInteger(1)) + if (isa(op) && elementTy.isBool()) return rewriter.create(loc, resultTypes, args); // tosa::LogicalNot - if (isa(op) && elementTy.isInteger(1)) { + if (isa(op) && elementTy.isBool()) { auto one = rewriter.create( loc, rewriter.getIntegerAttr(elementTy, 1)); return rewriter.create(loc, resultTypes, args[0], one); } // tosa::LogicalOr - if (isa(op) && elementTy.isInteger(1)) + if (isa(op) && elementTy.isBool()) return rewriter.create(loc, resultTypes, args); // tosa::LogicalXor - if (isa(op) && elementTy.isInteger(1)) + if (isa(op) && elementTy.isBool()) return rewriter.create(loc, resultTypes, args); // tosa::PowOp @@ -455,11 +455,11 @@ 
mlir::None); // 1-bit integers need to be treated as signless. - if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) + if (srcTy.isBool() && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) return rewriter.create(loc, resultTypes, args, mlir::None); - if (srcTy.isInteger(1) && dstTy.isa() && bitExtend) + if (srcTy.isBool() && dstTy.isa() && bitExtend) return rewriter.create(loc, resultTypes, args, mlir::None); @@ -482,7 +482,7 @@ mlir::None); // Casting to boolean, floats need to only be checked as not-equal to zero. - if (srcTy.isa() && dstTy.isInteger(1)) { + if (srcTy.isa() && dstTy.isBool()) { Value zero = rewriter.create( loc, rewriter.getFloatAttr(srcTy, 0.0)); return rewriter.create(loc, arith::CmpFPredicate::UNE, @@ -520,7 +520,7 @@ // Casting to boolean, integers need to only be checked as not-equal to // zero. - if (srcTy.isa() && dstTy.isInteger(1)) { + if (srcTy.isa() && dstTy.isBool()) { Value zero = rewriter.create( loc, 0, srcTy.getIntOrFloatBitWidth()); return rewriter.create(loc, arith::CmpIPredicate::ne, @@ -700,10 +700,10 @@ return rewriter.getIntegerAttr( elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())); - if (isa(op) && elementTy.isInteger(1)) + if (isa(op) && elementTy.isBool()) return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1)); - if (isa(op) && elementTy.isInteger(1)) + if (isa(op) && elementTy.isBool()) return rewriter.getIntegerAttr(elementTy, APInt::getZero(1)); if (isa(op) && elementTy.isa()) @@ -765,10 +765,10 @@ return rewriter.create(loc, predicate, args[0], args[1]); } - if (isa(op) && elementTy.isInteger(1)) + if (isa(op) && elementTy.isBool()) return rewriter.create(loc, args); - if (isa(op) && elementTy.isInteger(1)) + if (isa(op) && elementTy.isBool()) return rewriter.create(loc, args); return {}; diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ 
b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -1296,7 +1296,7 @@ if (matchPattern(getRhs(), m_Zero())) { if (auto extOp = getLhs().getDefiningOp()) { - if (extOp.getOperand().getType().cast().getWidth() == 1) { + if (extOp.getOperand().getType().isBool()) { // extsi(%x : i1 -> iN) != 0 -> %x if (getPredicate() == arith::CmpIPredicate::ne) { return extOp.getOperand(); @@ -1304,7 +1304,7 @@ } } if (auto extOp = getLhs().getDefiningOp()) { - if (extOp.getOperand().getType().cast().getWidth() == 1) { + if (extOp.getOperand().getType().isBool()) { // extui(%x : i1 -> iN) != 0 -> %x if (getPredicate() == arith::CmpIPredicate::ne) { return extOp.getOperand(); @@ -1705,7 +1705,7 @@ LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override { - if (!op.getType().isInteger(1)) + if (!op.getType().isBool()) return failure(); Value falseConstant = @@ -1729,7 +1729,7 @@ LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override { // Cannot extui i1 to i1, or i1 to f32 - if (!op.getType().isa() || op.getType().isInteger(1)) + if (!op.getType().isa() || op.getType().isBool()) return failure(); // select %x, c1, %c0 => extui %arg @@ -1778,7 +1778,7 @@ return falseVal; // select %x, true, false => %x - if (getType().isInteger(1)) + if (getType().isBool()) if (matchPattern(getTrueValue(), m_One())) if (matchPattern(getFalseValue(), m_Zero())) return condition; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1524,7 +1524,7 @@ if (!trueYield) continue; - if (!trueYield.getType().isInteger(1)) + if (!trueYield.getType().isBool()) continue; auto falseYield = falseResult.getDefiningOp(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp 
@@ -33,12 +33,12 @@ return llvm::None; auto type = boolAttr.getType(); - if (type.isInteger(1)) { + if (type.isBool()) { auto attr = boolAttr.cast(); return attr.getValue(); } if (auto vecType = type.cast()) { - if (vecType.getElementType().isInteger(1)) + if (vecType.getElementType().isBool()) if (auto attr = boolAttr.dyn_cast()) return attr.getSplatValue(); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -715,7 +715,7 @@ } // Handle the special encoding of splat of bool. - if (values.size() == 1 && values[0].getType().isInteger(1)) + if (values.size() == 1 && values[0].getType().isBool()) data[0] = data[0] ? -1 : 0; return DenseIntOrFPElementsAttr::getRaw(type, data); @@ -724,7 +724,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(hasSameElementsOrSplat(type, values)); - assert(type.getElementType().isInteger(1)); + assert(type.getElementType().isBool()); std::vector buff(llvm::divideCeil(values.size(), CHAR_BIT)); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -968,7 +968,7 @@ LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { for (auto resultType : op->getResultTypes()) { auto elementType = getTensorOrVectorElementType(resultType); - bool isBoolType = elementType.isInteger(1); + bool isBoolType = elementType.isBool(); if (!isBoolType) return op->emitOpError() << "requires a bool result type"; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -94,3 +94,5 @@ return intType.getWidth(); return cast().getWidth(); } + +bool Type::isBool() const { return isInteger(1); } diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp 
@@ -601,7 +601,7 @@ assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && "unexpected token type"); if (token.isAny(Token::kw_true, Token::kw_false)) { - if (!eltTy.isInteger(1)) { + if (!eltTy.isBool()) { return p.emitError(tokenLoc) << "expected i1 type for 'true' or 'false' values"; } diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -708,7 +708,7 @@ assert(dim <= shapedType.getRank()); if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast()) { - return attr.getType().getElementType().isInteger(1) + return attr.getType().getElementType().isBool() ? prepareConstantBool(loc, attr.getValues()[index]) : prepareConstantInt(loc, attr.getValues()[index]);