diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -21,7 +21,8 @@ // Type constraint accepting standard integers, indices and wrapped LLVM integer // types. def IntLikeOrLLVMInt : TypeConstraint< - Or<[AnySignlessInteger.predicate, Index.predicate, LLVMInt.predicate]>, + Or<[AnySignlessInteger.predicate, Index.predicate, + LLVM_AnyInteger.predicate]>, "integer, index or LLVM dialect equivalent">; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -17,6 +17,10 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +//===----------------------------------------------------------------------===// +// LLVM Dialect. +//===----------------------------------------------------------------------===// + def LLVM_Dialect : Dialect { let name = "llvm"; let cppNamespace = "LLVM"; @@ -38,34 +42,108 @@ }]; } -// LLVM IR type wrapped in MLIR. +//===----------------------------------------------------------------------===// +// LLVM dialect type constraints. +//===----------------------------------------------------------------------===// + +// LLVM dialect type. def LLVM_Type : DialectType()">, "LLVM dialect type">; -// Type constraint accepting only wrapped LLVM integer types. -def LLVMInt : TypeConstraint< - And<[LLVM_Type.predicate, - CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>, - "LLVM dialect integer">; +// Type constraint accepting LLVM integer types. +def LLVM_AnyInteger : Type< + CPred<"$_self.isa<::mlir::LLVM::LLVMIntegerType>()">, + "LLVM integer type">; + +// Type constraint accepting an LLVM integer type of a specific width.
+class LLVM_IntBase : + Type().getBitWidth() == " + # width>]>, + "LLVM " # width # "-bit integer type">, + BuildableType< + "::mlir::LLVM::LLVMIntegerType::get($_builder.getContext(), " + # width # ")">; + +def LLVM_i1 : LLVM_IntBase<1>; +def LLVM_i8 : LLVM_IntBase<8>; +def LLVM_i32 : LLVM_IntBase<32>; -def LLVMIntBase : TypeConstraint< +// Type constraint accepting LLVM primitive types, i.e. all types except void +// and function. +def LLVM_PrimitiveType : Type< And<[LLVM_Type.predicate, - CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>, - "LLVM dialect integer">; - -// Integer type of a specific width. -class LLVMI - : Type().isIntegerTy(" # width # ")">]>, - "LLVM dialect " # width # "-bit integer">, - BuildableType< - "::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext()," - # width # ")">; - -def LLVMI1 : LLVMI<1>; + CPred<"!$_self.isa<::mlir::LLVM::LLVMVoidType, " + "::mlir::LLVM::LLVMFunctionType>()">]>, + "primitive LLVM type">; + +// Type constraint accepting any LLVM floating point type (all formats). +def LLVM_AnyFloat : Type< + CPred<"$_self.isa<::mlir::LLVM::LLVMBFloatType, " + "::mlir::LLVM::LLVMHalfType, ::mlir::LLVM::LLVMFloatType, " + "::mlir::LLVM::LLVMDoubleType, ::mlir::LLVM::LLVMFP128Type, " + "::mlir::LLVM::LLVMX86FP80Type, ::mlir::LLVM::LLVMPPCFP128Type>()">, + "floating point LLVM type">; + +// Type constraint accepting any LLVM pointer type. +def LLVM_AnyPointer : Type()">, + "LLVM pointer type">; + +// Type constraint accepting LLVM pointer type with an additional constraint +// on the element type. +class LLVM_PointerTo : Type< + And<[LLVM_AnyPointer.predicate, + SubstLeaves< + "$_self", + "$_self.cast<::mlir::LLVM::LLVMPointerType>().getElementType()", + pointee.predicate>]>, + "LLVM pointer to " # pointee.description>; + +// Type constraint accepting any LLVM structure type. +def LLVM_AnyStruct : Type()">, + "LLVM structure type">; + +// Type constraint accepting opaque LLVM structure type.
+def LLVM_OpaqueStruct : Type<And<[LLVM_AnyStruct.predicate, + CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().isOpaque()">]>, + "opaque LLVM structure type">; + +// Type constraint accepting any LLVM type that can be loaded or stored, i.e. a +// type that has size (not void, function or opaque struct type). +def LLVM_LoadableType : Type< + And<[LLVM_PrimitiveType.predicate, Neg]>, + "LLVM type with size">; + +// Type constraint accepting any LLVM aggregate type, i.e. structure or array. +def LLVM_AnyAggregate : Type< + CPred<"$_self.isa<::mlir::LLVM::LLVMStructType, " + "::mlir::LLVM::LLVMArrayType>()">, + "LLVM aggregate type">; + +// Type constraint accepting any LLVM non-aggregate type, i.e. not structure or +// array. +def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate, + Neg<LLVM_AnyAggregate.predicate>]>, "LLVM non-aggregate type">; + +// Type constraint accepting any LLVM vector type. +def LLVM_AnyVector : Type()">, + "LLVM vector type">; + +// Type constraint accepting an LLVM vector type with an additional constraint +// on the vector element type. +class LLVM_VectorOf : Type< + And<[LLVM_AnyVector.predicate, + SubstLeaves< + "$_self", + "$_self.cast<::mlir::LLVM::LLVMVectorType>().getElementType()", + element.predicate>]>, + "LLVM vector of " # element.description>; + +// Type constraint accepting a constrained type, or a vector of such types. +class LLVM_ScalarOrVectorOf : + AnyTypeOf<[element, LLVM_VectorOf]>; // Base class for LLVM operations. Defines the interface to the llvm::IRBuilder // used to translate to LLVM IR proper. @@ -85,6 +163,10 @@ string llvmBuilder = ""; } +//===----------------------------------------------------------------------===// +// Base classes for LLVM dialect operations. +//===----------------------------------------------------------------------===// + // Base class for LLVM operations. All operations get an "llvm." prefix in // their name automatically.
LLVM operations have either zero or one result, // this class is specialized below for both cases and should not be used diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -87,39 +87,50 @@ LLVM_Op; // Class for arithmetic binary operations. -class LLVM_ArithmeticOp traits = []> : +class LLVM_ArithmeticOpBase traits = []> : LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>, + Arguments<(ins LLVM_ScalarOrVectorOf:$lhs, + LLVM_ScalarOrVectorOf:$rhs)>, LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> { - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let parser = + [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }]; } -class LLVM_UnaryArithmeticOp traits = []> : +class LLVM_IntArithmeticOp traits = []> : + LLVM_ArithmeticOpBase; +class LLVM_FloatArithmeticOp traits = []> : + LLVM_ArithmeticOpBase; + +// Class for arithmetic unary operations. +class LLVM_UnaryArithmeticOp traits = []> : LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$operand)>, + Arguments<(ins type:$operand)>, LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> { - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let parser = + [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }]; } // Integer binary operations.
-def LLVM_AddOp : LLVM_ArithmeticOp<"add", "CreateAdd", [Commutative]>; -def LLVM_SubOp : LLVM_ArithmeticOp<"sub", "CreateSub">; -def LLVM_MulOp : LLVM_ArithmeticOp<"mul", "CreateMul", [Commutative]>; -def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv", "CreateUDiv">; -def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv", "CreateSDiv">; -def LLVM_URemOp : LLVM_ArithmeticOp<"urem", "CreateURem">; -def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">; -def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">; -def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">; -def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">; -def LLVM_ShlOp : LLVM_ArithmeticOp<"shl", "CreateShl">; -def LLVM_LShrOp : LLVM_ArithmeticOp<"lshr", "CreateLShr">; -def LLVM_AShrOp : LLVM_ArithmeticOp<"ashr", "CreateAShr">; +def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "CreateAdd", [Commutative]>; +def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "CreateSub">; +def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "CreateMul", [Commutative]>; +def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "CreateUDiv">; +def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "CreateSDiv">; +def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "CreateURem">; +def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "CreateSRem">; +def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "CreateAnd">; +def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "CreateOr">; +def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "CreateXor">; +def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "CreateShl">; +def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "CreateLShr">; +def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "CreateAShr">; // Predicate for integer comparisons. def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; @@ -143,8 +154,9 @@ // Other integer operations.
def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, - Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs, - LLVM_Type:$rhs)> { + Arguments<(ins ICmpPredicate:$predicate, + LLVM_ScalarOrVectorOf:$lhs, + LLVM_ScalarOrVectorOf:$rhs)> { let llvmBuilder = [{ $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; @@ -189,8 +201,9 @@ -// Other integer operations. +// Other floating point operations. def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>, - Arguments<(ins FCmpPredicate:$predicate, LLVM_Type:$lhs, - LLVM_Type:$rhs)> { + Arguments<(ins FCmpPredicate:$predicate, + LLVM_ScalarOrVectorOf:$lhs, + LLVM_ScalarOrVectorOf:$rhs)> { let llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; @@ -205,12 +218,13 @@ } // Floating point binary operations. -def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd", "CreateFAdd">; -def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub", "CreateFSub">; -def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul", "CreateFMul">; -def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">; -def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">; -def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">; +def LLVM_FAddOp : LLVM_FloatArithmeticOp<"fadd", "CreateFAdd">; +def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">; +def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">; +def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">; +def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">; +def LLVM_FNegOp : LLVM_UnaryArithmeticOp, + "fneg", "CreateFNeg">; // Common code definition that is used to verify and set the alignment attribute // of LLVM ops that accept such an attribute.
@@ -241,7 +255,8 @@ def LLVM_AllocaOp : MemoryOpWithAlignmentBase, LLVM_OneResultOp<"alloca">, - Arguments<(ins LLVM_Type:$arraySize, OptionalAttr:$alignment)> { + Arguments<(ins LLVM_AnyInteger:$arraySize, + OptionalAttr:$alignment)> { string llvmBuilder = [{ auto *inst = builder.CreateAlloca( $_resultType->getPointerElementType(), $arraySize); @@ -259,8 +274,11 @@ let parser = [{ return parseAllocaOp(parser, result); }]; let printer = [{ printAllocaOp(p, *this); }]; } + def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$base, Variadic:$indices)>, + Arguments<(ins LLVM_ScalarOrVectorOf:$base, + Variadic>:$indices)>, LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> { let assemblyFormat = [{ $base `[` $indices `]` attr-dict `:` functional-type(operands, results) @@ -269,7 +287,7 @@ def LLVM_LoadOp : MemoryOpWithAlignmentAndAttributes, LLVM_OneResultOp<"load">, - Arguments<(ins LLVM_Type:$addr, + Arguments<(ins LLVM_PointerTo:$addr, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal)> { @@ -296,8 +314,8 @@ def LLVM_StoreOp : MemoryOpWithAlignmentAndAttributes, LLVM_ZeroResultOp<"store">, - Arguments<(ins LLVM_Type:$value, - LLVM_Type:$addr, + Arguments<(ins LLVM_LoadableType:$value, + LLVM_PointerTo:$addr, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal)> { @@ -314,28 +332,41 @@ } // Casts. -class LLVM_CastOp traits = []> : LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$arg)>, + Arguments<(ins type:$arg)>, LLVM_Builder<"$res = builder."
# builderFunc # "($arg, $_resultType);"> { let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; } -def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast">; -def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast">; -def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr">; -def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt">; -def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt">; -def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt">; -def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">; -def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP">; -def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP">; -def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI">; -def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI">; -def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">; -def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">; +def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast", + LLVM_AnyNonAggregate>; +def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast", + LLVM_ScalarOrVectorOf>; +def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr", + LLVM_ScalarOrVectorOf>; +def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt", + LLVM_ScalarOrVectorOf>; +def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt", + LLVM_ScalarOrVectorOf>; +def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt", + LLVM_ScalarOrVectorOf>; +def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc", + LLVM_ScalarOrVectorOf>; +def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP", + LLVM_ScalarOrVectorOf>; +def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP", + LLVM_ScalarOrVectorOf>; +def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI", + LLVM_ScalarOrVectorOf>; +def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI", + LLVM_ScalarOrVectorOf>; +def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt", +
LLVM_ScalarOrVectorOf>; +def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc", + LLVM_ScalarOrVectorOf>; // Call-related operations. def LLVM_InvokeOp : LLVM_Op<"invoke", [ @@ -404,8 +435,8 @@ let printer = [{ printCallOp(p, *this); }]; } def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$vector, - LLVM_Type:$position)> { + Arguments<(ins LLVM_AnyVector:$vector, + LLVM_AnyInteger:$position)> { string llvmBuilder = [{ $res = builder.CreateExtractElement($vector, $position); }]; @@ -416,8 +447,8 @@ let printer = [{ printExtractElementOp(p, *this); }]; } def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$container, - ArrayAttr:$position)> { + Arguments<(ins LLVM_AnyAggregate:$container, + ArrayAttr:$position)> { string llvmBuilder = [{ $res = builder.CreateExtractValue($container, extractPosition($position)); }]; @@ -425,8 +456,9 @@ let printer = [{ printExtractValueOp(p, *this); }]; } def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$vector, LLVM_Type:$value, - LLVM_Type:$position)> { + Arguments<(ins LLVM_AnyVector:$vector, + LLVM_PrimitiveType:$value, + LLVM_AnyInteger:$position)> { string llvmBuilder = [{ $res = builder.CreateInsertElement($vector, $value, $position); }]; @@ -434,8 +466,9 @@ let printer = [{ printInsertElementOp(p, *this); }]; } def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$container, LLVM_Type:$value, - ArrayAttr:$position)> { + Arguments<(ins LLVM_AnyAggregate:$container, + LLVM_PrimitiveType:$value, + ArrayAttr:$position)> { string llvmBuilder = [{ $res = builder.CreateInsertValue($container, $value, extractPosition($position)); @@ -451,7 +484,7 @@ } def LLVM_ShuffleVectorOp : LLVM_OneResultOp<"shufflevector", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, ArrayAttr:$mask)> { + Arguments<(ins
LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask)> { string llvmBuilder = [{ SmallVector position = extractPosition($mask); SmallVector mask(position.begin(), position.end()); @@ -478,8 +511,9 @@ def LLVM_SelectOp : LLVM_OneResultOp<"select", [NoSideEffect, AllTypesMatch<["trueValue", "falseValue", "res"]>]>, - Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue, - LLVM_Type:$falseValue)>, + Arguments<(ins LLVM_ScalarOrVectorOf:$condition, + LLVM_Type:$trueValue, + LLVM_Type:$falseValue)>, LLVM_Builder< "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { let builders = [OpBuilder< @@ -508,7 +542,7 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect]> { - let arguments = (ins LLVMI1:$condition, + let arguments = (ins LLVM_i1:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands, OptionalAttr:$branch_weights); @@ -1090,9 +1124,11 @@ let cppNamespace = "::mlir::LLVM"; } +def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyInteger]>; + def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">, - Arguments<(ins AtomicBinOp:$bin_op, LLVM_Type:$ptr, LLVM_Type:$val, - AtomicOrdering:$ordering)>, + Arguments<(ins AtomicBinOp:$bin_op, LLVM_PointerTo:$ptr, + LLVM_AtomicRMWType:$val, AtomicOrdering:$ordering)>, Results<(outs LLVM_Type:$res)> { let llvmBuilder = [{ $res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val, @@ -1103,8 +1139,11 @@ let verifier = "return ::verify(*this);"; } +def LLVM_AtomicCmpXchgType : AnyTypeOf<[LLVM_AnyInteger, LLVM_AnyPointer]>; + def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">, - Arguments<(ins LLVM_Type:$ptr, LLVM_Type:$cmp, LLVM_Type:$val, + Arguments<(ins LLVM_PointerTo:$ptr, + LLVM_AtomicCmpXchgType:$cmp, LLVM_AtomicCmpXchgType:$val, AtomicOrdering:$success_ordering, AtomicOrdering:$failure_ordering)>, Results<(outs LLVM_Type:$res)> { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1533,8 +1533,8 @@ static LogicalResult verify(AtomicRMWOp op) { + // The ODS constraint on `ptr` (LLVM_PointerTo) already guarantees an LLVM + // pointer type, so querying its element type below is safe. auto ptrType = op.ptr().getType().cast(); - if (!ptrType.isPointerTy()) - return op.emitOpError("expected LLVM IR pointer type for operand #0"); auto valType = op.val().getType().cast(); if (valType != ptrType.getPointerElementTy()) return op.emitOpError("expected LLVM IR element type for operand #0 to " diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -440,7 +440,9 @@ bool LLVMStructType::isPacked() { return getImpl()->isPacked(); } bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); } bool LLVMStructType::isOpaque() { - return getImpl()->isOpaque() || !getImpl()->isInitialized(); + // Only identified structs can be opaque; literal structs always report false. + return getImpl()->isIdentified() && + (getImpl()->isOpaque() || !getImpl()->isInitialized()); } bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); } StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -394,7 +394,7 @@ // CHECK-LABEL: @atomicrmw_expected_ptr func @atomicrmw_expected_ptr(%f32 : !llvm.float) { - // expected-error@+1 {{expected LLVM IR pointer type for operand #0}} + // expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or LLVM integer type}} %0 = "llvm.atomicrmw"(%f32, %f32) {bin_op=11, ordering=1} : (!llvm.float, !llvm.float) -> !llvm.float llvm.return } @@ -448,7 +448,7 @@ // CHECK-LABEL: @cmpxchg_expected_ptr func @cmpxchg_expected_ptr(%f32_ptr : !llvm.ptr, %f32 : !llvm.float) { - // expected-error@+1 {{expected LLVM IR pointer type for operand
#0}} + // expected-error@+1 {{op operand #0 must be LLVM pointer to LLVM integer type or LLVM pointer type}} %0 = "llvm.cmpxchg"(%f32, %f32, %f32) {success_ordering=2,failure_ordering=2} : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.struct<(float, i1)> llvm.return }