diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -623,7 +623,8 @@ def LLVM_CallOp : LLVM_Op<"call", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM function."; let description = [{ @@ -677,8 +678,8 @@ StringAttr::get($_builder.getContext(), callee), operands); }]>]; let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; } + def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position); let results = (outs LLVM_Type:$res); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1157,7 +1157,7 @@ return getOperands().drop_front(getCallee().has_value() ? 0 : 1); } -LogicalResult CallOp::verify() { +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { if (getNumResults() > 1) return emitOpError("must have 0 or 1 result"); @@ -1181,7 +1181,7 @@ fnType = ptrType.getElementType(); } else { Operation *callee = - SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); + symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr()); if (!callee) return emitOpError() << "'" << calleeName.getValue() diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -184,6 +184,7 @@ func.func @call_non_function_type(%callee : !llvm.func, %arg : i8) { // expected-error@+1 {{expected function type}} llvm.call %callee(%arg) : !llvm.func + llvm.return } // ----- @@ -191,6 +192,7 @@ func.func @invalid_call() { // expected-error@+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}} "llvm.call"() : () -> () + llvm.return } // ----- @@ -198,6 +200,7 @@ func.func @call_non_function_type(%callee : !llvm.func, %arg : i8) { // expected-error@+1 {{expected function type}} llvm.call %callee(%arg) : !llvm.func + llvm.return } // ----- @@ -205,6 +208,7 @@ func.func @call_unknown_symbol() { // expected-error@+1 {{'llvm.call' op 'missing_callee' does not reference a symbol in the current scope}} llvm.call @missing_callee() : () -> () + llvm.return } // ----- @@ -214,6 +218,7 @@ func.func @call_non_llvm() { // expected-error@+1 {{'llvm.call' op 'standard_func_callee' does not reference a valid LLVM function}} llvm.call @standard_func_callee() : () -> () + llvm.return } // ----- @@ -221,6 +226,7 @@ func.func @call_non_llvm_indirect(%arg0 : tensor<*xi32>) { // expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect-compatible type}} "llvm.call"(%arg0) : (tensor<*xi32>) -> () + llvm.return } // ----- @@ -230,6 +236,7 @@ func.func @callee_arg_mismatch(%arg0 : i32) { // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}} llvm.call @callee_func(%arg0) : (i32) -> () + llvm.return } // ----- @@ -237,6 +244,7 @@ func.func @indirect_callee_arg_mismatch(%arg0 : i32, %callee : !llvm.ptr>) { // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}} "llvm.call"(%callee, %arg0) : (!llvm.ptr>, i32) -> () + llvm.return } // ----- @@ -246,6 +254,7 @@ func.func @callee_return_mismatch() { // expected-error@+1 {{'llvm.call' op result type mismatch: 'i32' != 'i8'}} %res = llvm.call @callee_func() : () -> (i32) + llvm.return } // ----- @@ -253,6 +262,7 @@ func.func @indirect_callee_return_mismatch(%callee : !llvm.ptr>) { // expected-error@+1 {{'llvm.call' op result type mismatch: 'i32' != 'i8'}} "llvm.call"(%callee) : (!llvm.ptr>) -> (i32) + llvm.return } // ----- @@ -260,6 +270,7 @@ func.func @call_too_many_results(%callee : () -> (i32,i32)) { // expected-error@+1 {{expected function with 0 or 1 result}} llvm.call %callee() : () -> (i32, i32) + llvm.return } // ----- @@ -267,6 +278,7 @@ func.func @call_non_llvm_result(%callee : () -> (tensor<*xi32>)) { // expected-error@+1 {{expected result to have LLVM type}} llvm.call %callee() : () -> (tensor<*xi32>) + llvm.return } // ----- @@ -274,6 +286,7 @@ func.func @call_non_llvm_input(%callee : (tensor<*xi32>) -> (), %arg : tensor<*xi32>) { // expected-error@+1 {{expected LLVM types as inputs}} llvm.call %callee(%arg) : (tensor<*xi32>) -> () + llvm.return } // ----- @@ -540,7 +553,7 @@ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { // expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}} - %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] {layoutA=#nvvm.mma_layout, layoutB=#nvvm.mma_layout, shape = #nvvm.shape} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } @@ -564,7 +577,7 @@ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { // expected-error@+1 {{op requires attribute 'layoutA'}} - %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] {shape = #nvvm.shape}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } @@ -587,7 +600,7 @@ // expected-error@+1 {{op requires b1Op attribute}} %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> }