diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -78,6 +78,27 @@
   }
 };
 
+namespace impl {
+/// Unrolls op if it's operating on vectors.
+LogicalResult unrollVectorOp(Operation *op, ValueRange operands,
+                             ConversionPatternRewriter &rewriter,
+                             LLVMTypeConverter &converter);
+} // namespace impl
+
+/// Rewriting that unrolls SourceOp if it's operating on vectors.
+template <typename SourceOp>
+struct UnrollVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
+public:
+  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return impl::unrollVectorOp(op, adaptor.getOperands(), rewriter,
+                                *this->getTypeConverter());
+  }
+};
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -355,3 +356,45 @@
   rewriter.eraseOp(gpuPrintfOp);
   return success();
 }
+
+/// Unrolls op if it's operating on vectors.
+LogicalResult impl::unrollVectorOp(Operation *op, ValueRange operands,
+                                   ConversionPatternRewriter &rewriter,
+                                   LLVMTypeConverter &converter) {
+  TypeRange operandTypes(operands);
+  if (llvm::none_of(operandTypes,
+                    [](Type type) { return type.isa<VectorType>(); })) {
+    return rewriter.notifyMatchFailure(op, "expected vector operand");
+  }
+  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
+    return rewriter.notifyMatchFailure(op, "expected no region/successor");
+  if (op->getNumResults() != 1)
+    return rewriter.notifyMatchFailure(op, "expected single result");
+  VectorType vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+  if (!vectorType)
+    return rewriter.notifyMatchFailure(op, "expected vector result");
+
+  Location loc = op->getLoc();
+  Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
+  Type indexType = converter.convertType(rewriter.getIndexType());
+  StringAttr name = op->getName().getIdentifier();
+  Type elementType = vectorType.getElementType();
+
+  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
+    Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
+    auto extractElement = [&](Value operand) -> Value {
+      if (!operand.getType().isa<VectorType>())
+        return operand;
+      return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
+    };
+    auto scalarOperands =
+        llvm::to_vector(llvm::map_range(operands, extractElement));
+    Operation *scalarOp =
+        rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
+    result = rewriter.create<LLVM::InsertElementOp>(
+        loc, result, scalarOp->getResult(0), index);
+  }
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -20,7 +20,6 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -28,10 +27,8 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
 
 #include "../GPUCommon/GPUOpsLowering.h"
 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
@@ -254,42 +251,30 @@
       StringAttr::get(&converter.getContext(),
                       NVVM::NVVMDialect::getKernelFuncAttrName()));
 
-  patterns.add<OpToFuncCallLowering<math::AbsFOp>>(converter, "__nv_fabsf",
-                                                   "__nv_fabs");
-  patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__nv_atanf",
-                                                   "__nv_atan");
-  patterns.add<OpToFuncCallLowering<math::Atan2Op>>(converter, "__nv_atan2f",
-                                                    "__nv_atan2");
-  patterns.add<OpToFuncCallLowering<math::CeilOp>>(converter, "__nv_ceilf",
-                                                   "__nv_ceil");
-  patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__nv_cosf",
                                                  "__nv_cos");
-  patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__nv_expf",
-                                                  "__nv_exp");
-  patterns.add<OpToFuncCallLowering<math::Exp2Op>>(converter, "__nv_exp2f",
-                                                   "__nv_exp2");
-  patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(converter, "__nv_expm1f",
-                                                    "__nv_expm1");
-  patterns.add<OpToFuncCallLowering<math::FloorOp>>(converter, "__nv_floorf",
-                                                    "__nv_floor");
-  patterns.add<OpToFuncCallLowering<math::LogOp>>(converter, "__nv_logf",
-                                                  "__nv_log");
-  patterns.add<OpToFuncCallLowering<math::Log1pOp>>(converter, "__nv_log1pf",
-                                                    "__nv_log1p");
-  patterns.add<OpToFuncCallLowering<math::Log10Op>>(converter, "__nv_log10f",
-                                                    "__nv_log10");
-  patterns.add<OpToFuncCallLowering<math::Log2Op>>(converter, "__nv_log2f",
-                                                   "__nv_log2");
-  patterns.add<OpToFuncCallLowering<math::PowFOp>>(converter, "__nv_powf",
-                                                   "__nv_pow");
-  patterns.add<OpToFuncCallLowering<math::RsqrtOp>>(converter, "__nv_rsqrtf",
-                                                    "__nv_rsqrt");
-  patterns.add<OpToFuncCallLowering<math::SinOp>>(converter, "__nv_sinf",
-                                                  "__nv_sin");
-  patterns.add<OpToFuncCallLowering<math::SqrtOp>>(converter, "__nv_sqrtf",
-                                                   "__nv_sqrt");
-  patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__nv_tanhf",
-                                                   "__nv_tanh");
+  auto addOpLowering = [&](auto dummy, StringRef f32Func, StringRef f64Func) {
+    using OpTy = decltype(dummy);
+    patterns.add<UnrollVectorOpLowering<OpTy>>(converter);
+    patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+  };
+
+  addOpLowering(math::AbsFOp{}, "__nv_fabsf", "__nv_fabs");
+  addOpLowering(math::AtanOp{}, "__nv_atanf", "__nv_atan");
+  addOpLowering(math::Atan2Op{}, "__nv_atan2f", "__nv_atan2");
+  addOpLowering(math::CeilOp{}, "__nv_ceilf", "__nv_ceil");
+  addOpLowering(math::CosOp{}, "__nv_cosf", "__nv_cos");
+  addOpLowering(math::ExpOp{}, "__nv_expf", "__nv_exp");
+  addOpLowering(math::Exp2Op{}, "__nv_exp2f", "__nv_exp2");
+  addOpLowering(math::ExpM1Op{}, "__nv_expm1f", "__nv_expm1");
+  addOpLowering(math::FloorOp{}, "__nv_floorf", "__nv_floor");
+  addOpLowering(math::LogOp{}, "__nv_logf", "__nv_log");
+  addOpLowering(math::Log1pOp{}, "__nv_log1pf", "__nv_log1p");
+  addOpLowering(math::Log10Op{}, "__nv_log10f", "__nv_log10");
+  addOpLowering(math::Log2Op{}, "__nv_log2f", "__nv_log2");
+  addOpLowering(math::PowFOp{}, "__nv_powf", "__nv_pow");
+  addOpLowering(math::RsqrtOp{}, "__nv_rsqrtf", "__nv_rsqrt");
+  addOpLowering(math::SinOp{}, "__nv_sinf", "__nv_sin");
+  addOpLowering(math::SqrtOp{}, "__nv_sqrtf", "__nv_sqrt");
+  addOpLowering(math::TanhOp{}, "__nv_tanhf", "__nv_tanh");
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -184,42 +184,30 @@
     patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
   }
 
-  patterns.add<OpToFuncCallLowering<math::AbsFOp>>(converter, "__ocml_fabs_f32",
-                                                   "__ocml_fabs_f64");
-  patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__ocml_atan_f32",
-                                                   "__ocml_atan_f64");
-  patterns.add<OpToFuncCallLowering<math::Atan2Op>>(
-      converter, "__ocml_atan2_f32", "__ocml_atan2_f64");
-  patterns.add<OpToFuncCallLowering<math::CeilOp>>(converter, "__ocml_ceil_f32",
-                                                   "__ocml_ceil_f64");
-  patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__ocml_cos_f32",
-                                                  "__ocml_cos_f64");
-  patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__ocml_exp_f32",
-                                                  "__ocml_exp_f64");
-  patterns.add<OpToFuncCallLowering<math::Exp2Op>>(converter, "__ocml_exp2_f32",
-                                                   "__ocml_exp2_f64");
-  patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(
-      converter, "__ocml_expm1_f32", "__ocml_expm1_f64");
-  patterns.add<OpToFuncCallLowering<math::FloorOp>>(
-      converter, "__ocml_floor_f32", "__ocml_floor_f64");
-  patterns.add<OpToFuncCallLowering<math::LogOp>>(converter, "__ocml_log_f32",
-                                                  "__ocml_log_f64");
-  patterns.add<OpToFuncCallLowering<math::Log10Op>>(
-      converter, "__ocml_log10_f32", "__ocml_log10_f64");
-  patterns.add<OpToFuncCallLowering<math::Log1pOp>>(
-      converter, "__ocml_log1p_f32", "__ocml_log1p_f64");
-  patterns.add<OpToFuncCallLowering<math::Log2Op>>(converter, "__ocml_log2_f32",
-                                                   "__ocml_log2_f64");
-  patterns.add<OpToFuncCallLowering<math::PowFOp>>(converter, "__ocml_pow_f32",
-                                                   "__ocml_pow_f64");
-  patterns.add<OpToFuncCallLowering<math::RsqrtOp>>(
-      converter, "__ocml_rsqrt_f32", "__ocml_rsqrt_f64");
-  patterns.add<OpToFuncCallLowering<math::SinOp>>(converter, "__ocml_sin_f32",
-                                                  "__ocml_sin_f64");
-  patterns.add<OpToFuncCallLowering<math::SqrtOp>>(converter, "__ocml_sqrt_f32",
-                                                   "__ocml_sqrt_f64");
-  patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__ocml_tanh_f32",
-                                                   "__ocml_tanh_f64");
+  auto addOpLowering = [&](auto dummy, StringRef f32Func, StringRef f64Func) {
+    using OpTy = decltype(dummy);
+    patterns.add<UnrollVectorOpLowering<OpTy>>(converter);
+    patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+  };
+
+  addOpLowering(math::AbsFOp{}, "__ocml_fabs_f32", "__ocml_fabs_f64");
+  addOpLowering(math::AtanOp{}, "__ocml_atan_f32", "__ocml_atan_f64");
+  addOpLowering(math::Atan2Op{}, "__ocml_atan2_f32", "__ocml_atan2_f64");
+  addOpLowering(math::CeilOp{}, "__ocml_ceil_f32", "__ocml_ceil_f64");
+  addOpLowering(math::CosOp{}, "__ocml_cos_f32", "__ocml_cos_f64");
+  addOpLowering(math::ExpOp{}, "__ocml_exp_f32", "__ocml_exp_f64");
+  addOpLowering(math::Exp2Op{}, "__ocml_exp2_f32", "__ocml_exp2_f64");
+  addOpLowering(math::ExpM1Op{}, "__ocml_expm1_f32", "__ocml_expm1_f64");
+  addOpLowering(math::FloorOp{}, "__ocml_floor_f32", "__ocml_floor_f64");
+  addOpLowering(math::LogOp{}, "__ocml_log_f32", "__ocml_log_f64");
+  addOpLowering(math::Log10Op{}, "__ocml_log10_f32", "__ocml_log10_f64");
+  addOpLowering(math::Log1pOp{}, "__ocml_log1p_f32", "__ocml_log1p_f64");
+  addOpLowering(math::Log2Op{}, "__ocml_log2_f32", "__ocml_log2_f64");
+  addOpLowering(math::PowFOp{}, "__ocml_pow_f32", "__ocml_pow_f64");
+  addOpLowering(math::RsqrtOp{}, "__ocml_rsqrt_f32", "__ocml_rsqrt_f64");
+  addOpLowering(math::SinOp{}, "__ocml_sin_f32", "__ocml_sin_f64");
+  addOpLowering(math::SqrtOp{}, "__ocml_sqrt_f32", "__ocml_sqrt_f64");
+  addOpLowering(math::TanhOp{}, "__ocml_tanh_f32", "__ocml_tanh_f64");
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -478,6 +478,20 @@
 
 // -----
 
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_unroll
+  func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
+    %result = math.exp %arg0 : vector<4xf32>
+    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+    func.return %result : vector<4xf32>
+  }
+}
+
+// -----
+
 gpu.module @test_module {
   // CHECK-LABEL: @kernel_func
   // CHECK: attributes
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -377,6 +377,20 @@
 
 // -----
 
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_unroll
+  func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
+    %result = math.exp %arg0 : vector<4xf32>
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    func.return %result : vector<4xf32>
+  }
+}
+
+// -----
+
 gpu.module @test_module {
   // CHECK-LABEL: @kernel_func
   // CHECK: attributes