diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -94,9 +94,8 @@ LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> { let arguments = (ins LLVM_ScalarOrVectorOf:$lhs, LLVM_ScalarOrVectorOf:$rhs); - let parser = - [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }]; + let results = (outs LLVM_ScalarOrVectorOf:$res); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($res)"; } class LLVM_IntArithmeticOp traits = []> : @@ -112,9 +111,8 @@ !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>, LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> { let arguments = (ins type:$operand); - let parser = - [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }]; + let results = (outs type:$res); + let assemblyFormat = "$operand attr-dict `:` type($res)"; } // Integer binary operations. @@ -157,6 +155,7 @@ let arguments = (ins ICmpPredicate:$predicate, LLVM_ScalarOrVectorOf:$lhs, LLVM_ScalarOrVectorOf:$rhs); + let results = (outs LLVM_ScalarOrVectorOf:$res); let llvmBuilder = [{ $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; @@ -204,6 +203,7 @@ let arguments = (ins FCmpPredicate:$predicate, LLVM_ScalarOrVectorOf:$lhs, LLVM_ScalarOrVectorOf:$rhs); + let results = (outs LLVM_ScalarOrVectorOf:$res); let llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; @@ -257,6 +257,7 @@ LLVM_OneResultOp<"alloca"> { let arguments = (ins LLVM_AnyInteger:$arraySize, OptionalAttr:$alignment); + let results = (outs LLVM_AnyPointer:$res); string llvmBuilder = [{ auto *inst = builder.CreateAlloca( $_resultType->getPointerElementType(), $arraySize); @@ -280,6 +281,7 @@ LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, Variadic>:$indices); + let results = (outs LLVM_ScalarOrVectorOf:$res); let assemblyFormat = [{ $base `[` $indices `]` attr-dict `:` functional-type(operands, results) }]; @@ -291,6 +293,7 @@ let arguments = (ins LLVM_PointerTo:$addr, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); + let results = (outs LLVM_Type:$res); string llvmBuilder = [{ auto *inst = builder.CreateLoad($addr, $volatile_); }] # setAlignmentCode # setNonTemporalMetadataCode # [{ @@ -330,52 +333,64 @@ // Casts. class LLVM_CastOp traits = []> : + Type resultType, list traits = []> : LLVM_OneResultOp, LLVM_Builder<"$res = builder." # builderFunc # "($arg, $_resultType);"> { let arguments = (ins type:$arg); + let results = (outs resultType:$res); let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; } def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast", - LLVM_AnyNonAggregate>; + LLVM_AnyNonAggregate, LLVM_AnyNonAggregate>; def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast", + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr", - LLVM_ScalarOrVectorOf>; -def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt", + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; +def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt", + LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf>; def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt", + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt", + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc", + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP", - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf>; def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP", - LLVM_ScalarOrVectorOf>; -def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI", + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; +def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI", + LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf>; def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI", - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf>; def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt", + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc", + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; // Call-related operations. def LLVM_InvokeOp : LLVM_Op<"invoke", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, - Terminator - ]>, - Results<(outs Variadic)> { + Terminator]> { let arguments = (ins OptionalAttr:$callee, Variadic:$operands, Variadic:$normalDestOperands, Variadic:$unwindDestOperands); + let results = (outs Variadic); let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); @@ -400,15 +415,16 @@ def LLVM_LandingpadOp : LLVM_OneResultOp<"landingpad"> { let arguments = (ins UnitAttr:$cleanup, Variadic); + let results = (outs LLVM_Type:$res); let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseLandingpadOp(parser, result); }]; let printer = [{ printLandingpadOp(p, *this); }]; } -def LLVM_CallOp : LLVM_Op<"call">, - Results<(outs Variadic)> { +def LLVM_CallOp : LLVM_Op<"call"> { let arguments = (ins OptionalAttr:$callee, Variadic); + let results = (outs Variadic); let builders = [ OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands, CArg<"ArrayRef", "{}">:$attributes), @@ -426,6 +442,7 @@ } def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, LLVM_AnyInteger:$position); + let results = (outs LLVM_Type:$res); string llvmBuilder = [{ $res = builder.CreateExtractElement($vector, $position); }]; @@ -437,6 +454,7 @@ } def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]> { let arguments = (ins LLVM_AnyAggregate:$container, ArrayAttr:$position); + let results = (outs LLVM_Type:$res); string llvmBuilder = [{ $res = builder.CreateExtractValue($container, extractPosition($position)); }]; @@ -446,6 +464,7 @@ def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value, LLVM_AnyInteger:$position); + let results = (outs LLVM_AnyVector:$res); string llvmBuilder = [{ $res = builder.CreateInsertElement($vector, $value, $position); }]; @@ -455,6 +474,7 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]> { let arguments = (ins LLVM_AnyAggregate:$container, LLVM_PrimitiveType:$value, ArrayAttr:$position); + let results = (outs LLVM_AnyAggregate:$res); string llvmBuilder = [{ $res = builder.CreateInsertValue($container, $value, extractPosition($position)); @@ -469,6 +489,7 @@ } def LLVM_ShuffleVectorOp : LLVM_OneResultOp<"shufflevector", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask); + let results = (outs LLVM_AnyVector:$res); string llvmBuilder = [{ SmallVector position = extractPosition($mask); SmallVector mask(position.begin(), position.end()); @@ -499,6 +520,7 @@ "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { let arguments = (ins LLVM_ScalarOrVectorOf:$condition, LLVM_Type:$trueValue, LLVM_Type:$falseValue); + let results = (outs LLVM_Type:$res); let builders = [ OpBuilderDAG<(ins "Value":$condition, "Value":$lhs, "Value":$rhs), [{ @@ -508,6 +530,7 @@ } def LLVM_FreezeOp : LLVM_OneResultOp<"freeze", [SameOperandsAndResultType]> { let arguments = (ins LLVM_Type:$val); + let results = (outs LLVM_Type:$res); let assemblyFormat = "$val attr-dict `:` type($val)"; string llvmBuilder = "builder.CreateFreeze($val);"; } @@ -641,6 +664,7 @@ def LLVM_AddressOfOp : LLVM_OneResultOp<"mlir.addressof"> { let arguments = (ins FlatSymbolRefAttr:$global_name); + let results = (outs LLVM_Type:$res); let summary = "Creates a pointer pointing to a global or a function"; @@ -796,12 +820,13 @@ : LLVM_OneResultOp<"mlir.null", [NoSideEffect]>, LLVM_Builder<"$res = llvm::ConstantPointerNull::get(" " cast($_resultType));"> { + let results = (outs LLVM_AnyPointer:$res); let assemblyFormat = "attr-dict `:` type($res)"; - let verifier = [{ return ::verify(*this); }]; } def LLVM_UndefOp : LLVM_OneResultOp<"mlir.undef", [NoSideEffect]>, LLVM_Builder<"$res = llvm::UndefValue::get($_resultType);"> { + let results = (outs LLVM_Type:$res); let assemblyFormat = "attr-dict `:` type($res)"; } def LLVM_ConstantOp @@ -809,12 +834,12 @@ LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);"> { let arguments = (ins AnyAttr:$value); + let results = (outs LLVM_Type:$res); let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)"; let verifier = [{ return ::verify(*this); }]; } -def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>, - Results<(outs AnyType:$res)> { +def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]> { let summary = "Type cast between LLVM dialect and Standard."; let description = [{ llvm.mlir.cast op casts between Standard and LLVM dialects. It only changes @@ -828,6 +853,7 @@ llvm.mlir.cast %v : !llvm<"<2 x float>"> to vector<2xf32> }]; let arguments = (ins AnyType:$in); + let results = (outs AnyType:$res); let assemblyFormat = "$in attr-dict `:` type($in) `to` type($res)"; let verifier = "return ::verify(*this);"; } @@ -951,6 +977,7 @@ : LLVM_OneResultOp<"intr.matrix.column.major.load"> { let arguments = (ins LLVM_Type:$data, LLVM_Type:$stride, I1Attr:$isVolatile, I32Attr:$rows, I32Attr:$columns); + let results = (outs LLVM_Type:$res); string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); const llvm::DataLayout &dl = @@ -997,6 +1024,7 @@ : LLVM_OneResultOp<"intr.matrix.multiply"> { let arguments = (ins LLVM_Type:$lhs, LLVM_Type:$rhs, I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns); + let results = (outs LLVM_Type:$res); string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); $res = mb.CreateMatrixMultiply( @@ -1011,6 +1039,7 @@ /// `matrix`, as specified in the LLVM MatrixBuilder. def LLVM_MatrixTransposeOp : LLVM_OneResultOp<"intr.matrix.transpose"> { let arguments = (ins LLVM_Type:$matrix, I32Attr:$rows, I32Attr:$columns); + let results = (outs LLVM_Type:$res); string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); $res = mb.CreateMatrixTranspose( @@ -1035,6 +1064,7 @@ def LLVM_MaskedLoadOp : LLVM_OneResultOp<"intr.masked.load"> { let arguments = (ins LLVM_Type:$data, LLVM_Type:$mask, Variadic:$pass_thru, I32Attr:$alignment); + let results = (outs LLVM_Type:$res); string llvmBuilder = [{ $res = $pass_thru.empty() ? builder.CreateMaskedLoad( $data, llvm::Align($alignment), $mask) : @@ -1061,6 +1091,7 @@ def LLVM_masked_gather : LLVM_OneResultOp<"intr.masked.gather"> { let arguments = (ins LLVM_Type:$ptrs, LLVM_Type:$mask, Variadic:$pass_thru, I32Attr:$alignment); + let results = (outs LLVM_Type:$res); string llvmBuilder = [{ $res = $pass_thru.empty() ? builder.CreateMaskedGather( $ptrs, llvm::Align($alignment), $mask) : @@ -1139,11 +1170,11 @@ def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyInteger]>; -def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">, - Results<(outs LLVM_Type:$res)> { +def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> { let arguments = (ins AtomicBinOp:$bin_op, LLVM_PointerTo:$ptr, LLVM_AtomicRMWType:$val, AtomicOrdering:$ordering); + let results = (outs LLVM_AtomicRMWType:$res); let llvmBuilder = [{ $res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val, getLLVMAtomicOrdering($ordering)); @@ -1154,12 +1185,24 @@ } def LLVM_AtomicCmpXchgType : AnyTypeOf<[LLVM_AnyInteger, LLVM_AnyPointer]>; - -def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">, Results<(outs LLVM_Type:$res)> { +def LLVM_AtomicCmpXchgResultType : Type().getBody().size() == 2">, + SubstLeaves<"$_self", + "$_self.cast<::mlir::LLVM::LLVMStructType>().getBody()[0]", + LLVM_AtomicCmpXchgType.predicate>, + SubstLeaves<"$_self", + "$_self.cast<::mlir::LLVM::LLVMStructType>().getBody()[1]", + LLVM_i1.predicate>]>, + "an LLVM struct type with any integer or pointer followed by a single-bit " + "integer">; + +def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg"> { let arguments = (ins LLVM_PointerTo:$ptr, LLVM_AtomicCmpXchgType:$cmp, LLVM_AtomicCmpXchgType:$val, AtomicOrdering:$success_ordering, AtomicOrdering:$failure_ordering); + let results = (outs LLVM_AtomicCmpXchgResultType:$res); let llvmBuilder = [{ $res = builder.CreateAtomicCmpXchg($ptr, $cmp, $val, getLLVMAtomicOrdering($success_ordering), diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1510,18 +1510,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// Verification for LLVM::NullOp. -//===----------------------------------------------------------------------===// - -// Only LLVM pointer types are supported. -static LogicalResult verify(LLVM::NullOp op) { - auto llvmType = op.getType().dyn_cast(); - if (!llvmType || !llvmType.isPointerTy()) - return op.emitOpError("expected LLVM IR pointer type"); - return success(); -} - //===----------------------------------------------------------------------===// // Verification for LLVM::ConstantOp. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -339,7 +339,7 @@ // ----- func @null_non_llvm_type() { - // expected-error@+1 {{expected LLVM IR pointer type}} + // expected-error@+1 {{must be LLVM pointer type, but got '!llvm.i32'}} llvm.mlir.null : !llvm.i32 }