diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -779,6 +779,232 @@
   }];
 }
 
+//
+// Vector Reductions.
+//
+// Each op below is translated to LLVM IR by declaring the matching
+// llvm.experimental.vector.reduce.* intrinsic in the current module and
+// emitting a call to it with the op's operands. The declaration is overloaded
+// on the type of operand 0, except for the v2.fadd/v2.fmul ops, which are
+// overloaded on the types of result 0 and operand 1.
+//
+
+def LLVM_experimental_vector_reduce_add :
+    LLVM_IntrOp<"experimental.vector.reduce.add", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_add, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_and :
+    LLVM_IntrOp<"experimental.vector.reduce.and", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_and, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_fmax :
+    LLVM_IntrOp<"experimental.vector.reduce.fmax", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_fmax, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_fmin :
+    LLVM_IntrOp<"experimental.vector.reduce.fmin", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_fmin, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_mul :
+    LLVM_IntrOp<"experimental.vector.reduce.mul", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_mul, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_or :
+    LLVM_IntrOp<"experimental.vector.reduce.or", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_or, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_smax :
+    LLVM_IntrOp<"experimental.vector.reduce.smax", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_smax, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_smin :
+    LLVM_IntrOp<"experimental.vector.reduce.smin", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_smin, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_umax :
+    LLVM_IntrOp<"experimental.vector.reduce.umax", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_umax, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_umin :
+    LLVM_IntrOp<"experimental.vector.reduce.umin", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_umin, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_v2_fadd :
+    LLVM_IntrOp<"experimental.vector.reduce.v2.fadd", []>,
+    Arguments<(ins LLVM_Type, LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_v2_fadd, {
+        opInst.getResult(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        opInst.getOperand(1).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_v2_fmul :
+    LLVM_IntrOp<"experimental.vector.reduce.v2.fmul", []>,
+    Arguments<(ins LLVM_Type, LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_v2_fmul, {
+        opInst.getResult(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        opInst.getOperand(1).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+def LLVM_experimental_vector_reduce_xor :
+    LLVM_IntrOp<"experimental.vector.reduce.xor", []>,
+    Arguments<(ins LLVM_Type)>, Results<(outs LLVM_Type:$res)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module, llvm::Intrinsic::experimental_vector_reduce_xor, {
+        opInst.getOperand(0).getType().cast<LLVM::LLVMType>()
+              .getUnderlyingType(),
+        });
+    auto operands =
+        lookupValues(opInst.getOperands());
+    $res = builder.CreateCall(fn, operands);
+  }];
+}
+
+//
+// Atomic operations.
+//
+
 def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0>;
 def AtomicBinOpAdd  : I64EnumAttrCase<"add", 1>;
 def AtomicBinOpSub  : I64EnumAttrCase<"sub", 2>;
diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -95,6 +95,37 @@
   llvm.return
 }
 
+// CHECK-LABEL: @vector_reductions
+llvm.func @vector_reductions(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">, %arg2: !llvm<"<8 x i32>">) {
+  // CHECK: call i32 @llvm.experimental.vector.reduce.add.v8i32
+  "llvm.intr.experimental.vector.reduce.add"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  // CHECK: call i32 @llvm.experimental.vector.reduce.and.v8i32
+  "llvm.intr.experimental.vector.reduce.and"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  // CHECK: call float @llvm.experimental.vector.reduce.fmax.v8f32
+  "llvm.intr.experimental.vector.reduce.fmax"(%arg1) : (!llvm<"<8 x float>">) -> !llvm.float
+  // CHECK: call float @llvm.experimental.vector.reduce.fmin.v8f32
+  "llvm.intr.experimental.vector.reduce.fmin"(%arg1) : (!llvm<"<8 x float>">) -> !llvm.float
+  // CHECK: call i32 @llvm.experimental.vector.reduce.mul.v8i32
+  "llvm.intr.experimental.vector.reduce.mul"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  // CHECK: call i32 @llvm.experimental.vector.reduce.or.v8i32
+  "llvm.intr.experimental.vector.reduce.or"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  // CHECK: call i32 @llvm.experimental.vector.reduce.smax.v8i32
+  "llvm.intr.experimental.vector.reduce.smax"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  // CHECK: call i32 @llvm.experimental.vector.reduce.smin.v8i32
+  "llvm.intr.experimental.vector.reduce.smin"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  // CHECK: call i32 @llvm.experimental.vector.reduce.umax.v8i32
+  "llvm.intr.experimental.vector.reduce.umax"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  // CHECK: call i32 @llvm.experimental.vector.reduce.umin.v8i32
+  "llvm.intr.experimental.vector.reduce.umin"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  // CHECK: call float @llvm.experimental.vector.reduce.v2.fadd.f32.v8f32
+  "llvm.intr.experimental.vector.reduce.v2.fadd"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
+  // CHECK: call float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32
+  "llvm.intr.experimental.vector.reduce.v2.fmul"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
+  // CHECK: call i32 @llvm.experimental.vector.reduce.xor.v8i32
+  "llvm.intr.experimental.vector.reduce.xor"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
+  llvm.return
+}
+
 // Check that intrinsics are declared with appropriate types.
 // CHECK: declare float @llvm.fmuladd.f32(float, float, float)
 // CHECK: declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>) #0