diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -295,7 +295,8 @@ static ParseResult parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, SmallVectorImpl> &modifiers, - Optional &chunkSize) { + Optional &chunkSize, + Type &chunkType) { if (parser.parseLParen()) return failure(); @@ -307,7 +308,7 @@ if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { if (succeeded(parser.parseOptionalEqual())) { chunkSize = OpAsmParser::OperandType{}; - if (parser.parseOperand(*chunkSize)) + if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) return failure(); } else { chunkSize = llvm::NoneType::None; @@ -341,7 +342,7 @@ Value scheduleChunkVar) { p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); if (scheduleChunkVar) - p << " = " << scheduleChunkVar; + p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); if (modifier) p << ", " << stringifyScheduleModifier(*modifier); if (simd) @@ -631,6 +632,7 @@ SmallString<8> schedule; SmallVector> modifiers; Optional scheduleChunkSize; + Type scheduleChunkType; // Compute the position of clauses in operand segments int currPos = 0; @@ -751,7 +753,8 @@ clauseSegments[pos[linearClause] + 1] = linearSteps.size(); } else if (clauseKeyword == "schedule") { if (checkAllowed(scheduleClause) || - parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize)) + parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize, + scheduleChunkType)) return failure(); if (scheduleChunkSize) { clauseSegments[pos[scheduleClause]] = 1; @@ -906,10 +909,9 @@ result.addAttribute("simd_modifier", attr); } } - if (scheduleChunkSize) { - auto chunkSizeType = parser.getBuilder().getI32Type(); - parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); - } + if (scheduleChunkSize) + parser.resolveOperand(*scheduleChunkSize, scheduleChunkType, + result.operands); } segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -646,10 +646,20 @@ // Find the loop configuration. llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]); llvm::Type *ivType = step->getType(); - llvm::Value *chunk = - loop.schedule_chunk_var() - ? moduleTranslation.lookupValue(loop.schedule_chunk_var()) - : llvm::ConstantInt::get(ivType, 1); + llvm::Value *chunk = nullptr; + if (loop.schedule_chunk_var()) { + llvm::Value *chunkVar = + moduleTranslation.lookupValue(loop.schedule_chunk_var()); + llvm::Type *chunkVarType = chunkVar->getType(); + assert(chunkVarType->isIntegerTy() && + "chunk size must be one integer expression"); + if (chunkVarType->getIntegerBitWidth() < ivType->getIntegerBitWidth()) + chunk = builder.CreateSExt(chunkVar, ivType); + else if (chunkVarType->getIntegerBitWidth() > ivType->getIntegerBitWidth()) + chunk = builder.CreateTrunc(chunkVar, ivType); + else + chunk = chunkVar; + } SmallVector reductionDecls; collectReductionDecls(loop, reductionDecls); diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -188,8 +188,7 @@ } // CHECK-LABEL: omp_wsloop_pretty -func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index, - %data_var : memref, %linear_var : i32, %chunk_var : i32) -> () { +func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index, %data_var : memref, %linear_var : i32, %chunk_var : i32, %chunk_var2 : i16) -> () { // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) private(%data_var : memref) collapse(2) ordered(2) { @@ -201,24 +200,24 @@ omp.yield } - // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) linear(%{{.*}} = %{{.*}} : memref) schedule(static = %{{.*}}) collapse(3) ordered(2) + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) linear(%{{.*}} = %{{.*}} : memref) schedule(static = %{{.*}} : i32) collapse(3) ordered(2) omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) ordered(2) private(%data_var : memref) firstprivate(%data_var : memref) lastprivate(%data_var : memref) linear(%data_var = %linear_var : memref) - schedule(static = %chunk_var) collapse(3) { + schedule(static = %chunk_var : i32) collapse(3) { omp.yield } - // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) linear(%{{.*}} = %{{.*}} : memref) schedule(dynamic = %{{.*}}, nonmonotonic) collapse(3) ordered(2) + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) linear(%{{.*}} = %{{.*}} : memref) schedule(dynamic = %{{.*}} : i32, nonmonotonic) collapse(3) ordered(2) omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) ordered(2) private(%data_var : memref) firstprivate(%data_var : memref) lastprivate(%data_var : memref) linear(%data_var = %linear_var : memref) - schedule(dynamic = %chunk_var, nonmonotonic) collapse(3) { + schedule(dynamic = %chunk_var : i32, nonmonotonic) collapse(3) { omp.yield } - // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) linear(%{{.*}} = %{{.*}} : memref) schedule(dynamic = %{{.*}}, monotonic) collapse(3) ordered(2) + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) linear(%{{.*}} = %{{.*}} : memref) schedule(dynamic = %{{.*}} : i16, monotonic) collapse(3) ordered(2) omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) ordered(2) private(%data_var : memref) firstprivate(%data_var : memref) lastprivate(%data_var : memref) linear(%data_var = %linear_var : memref) - schedule(dynamic = %chunk_var, monotonic) collapse(3) { + schedule(dynamic = %chunk_var2 : i16, monotonic) collapse(3) { omp.yield } diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -430,7 +430,24 @@ // CHECK: call void @__kmpc_dispatch_init_8u // CHECK: %[[continue:.*]] = call i32 @__kmpc_dispatch_next_8u // CHECK: %[[cond:.*]] = icmp ne i32 %[[continue]], 0 - // CHECK br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}} + // CHECK: br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}} + llvm.call @body(%iv) : (i64) -> () + omp.yield + } + llvm.return +} + +// ----- + +llvm.func @body(i64) + +llvm.func @test_omp_wsloop_dynamic_chunk_const(%lb : i64, %ub : i64, %step : i64) -> () { + %chunk_size_const = llvm.mlir.constant(2 : i16) : i16 + omp.wsloop (%iv) : i64 = (%lb) to (%ub) step (%step) schedule(dynamic = %chunk_size_const : i16) { + // CHECK: call void @__kmpc_dispatch_init_8u(%struct.ident_t* @{{.*}}, i32 %{{.*}}, i32 35, i64 {{.*}}, i64 %{{.*}}, i64 {{.*}}, i64 2) + // CHECK: %[[continue:.*]] = call i32 @__kmpc_dispatch_next_8u + // CHECK: %[[cond:.*]] = icmp ne i32 %[[continue]], 0 + // CHECK: br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}} llvm.call @body(%iv) : (i64) -> () omp.yield } @@ -439,6 +456,62 @@ // ----- +llvm.func @body(i32) + +llvm.func @test_omp_wsloop_dynamic_chunk_var(%lb : i32, %ub : i32, %step : i32) -> () { + %1 = llvm.mlir.constant(1 : i64) : i64 + %chunk_size_alloca = llvm.alloca %1 x i16 {bindc_name = "chunk_size", in_type = i16, uniq_name = "_QFsub1Echunk_size"} : (i64) -> !llvm.ptr + %chunk_size_var = llvm.load %chunk_size_alloca : !llvm.ptr + omp.wsloop (%iv) : i32 = (%lb) to (%ub) step (%step) schedule(dynamic = %chunk_size_var : i16) { + // CHECK: %[[CHUNK_SIZE:.*]] = sext i16 %{{.*}} to i32 + // CHECK: call void @__kmpc_dispatch_init_4u(%struct.ident_t* @{{.*}}, i32 %{{.*}}, i32 35, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, i32 %[[CHUNK_SIZE]]) + // CHECK: %[[continue:.*]] = call i32 @__kmpc_dispatch_next_4u + // CHECK: %[[cond:.*]] = icmp ne i32 %[[continue]], 0 + // CHECK: br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}} + llvm.call @body(%iv) : (i32) -> () + omp.yield + } + llvm.return +} + +// ----- + +llvm.func @body(i32) + +llvm.func @test_omp_wsloop_dynamic_chunk_var2(%lb : i32, %ub : i32, %step : i32) -> () { + %1 = llvm.mlir.constant(1 : i64) : i64 + %chunk_size_alloca = llvm.alloca %1 x i64 {bindc_name = "chunk_size", in_type = i64, uniq_name = "_QFsub1Echunk_size"} : (i64) -> !llvm.ptr + %chunk_size_var = llvm.load %chunk_size_alloca : !llvm.ptr + omp.wsloop (%iv) : i32 = (%lb) to (%ub) step (%step) schedule(dynamic = %chunk_size_var : i64) { + // CHECK: %[[CHUNK_SIZE:.*]] = trunc i64 %{{.*}} to i32 + // CHECK: call void @__kmpc_dispatch_init_4u(%struct.ident_t* @{{.*}}, i32 %{{.*}}, i32 35, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, i32 %[[CHUNK_SIZE]]) + // CHECK: %[[continue:.*]] = call i32 @__kmpc_dispatch_next_4u + // CHECK: %[[cond:.*]] = icmp ne i32 %[[continue]], 0 + // CHECK: br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}} + llvm.call @body(%iv) : (i32) -> () + omp.yield + } + llvm.return +} + +// ----- + +llvm.func @body(i32) + +llvm.func @test_omp_wsloop_dynamic_chunk_var3(%lb : i32, %ub : i32, %step : i32, %chunk_size : i32) -> () { + omp.wsloop (%iv) : i32 = (%lb) to (%ub) step (%step) schedule(dynamic = %chunk_size : i32) { + // CHECK: call void @__kmpc_dispatch_init_4u(%struct.ident_t* @{{.*}}, i32 %{{.*}}, i32 35, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, i32 %{{.*}}) + // CHECK: %[[continue:.*]] = call i32 @__kmpc_dispatch_next_4u + // CHECK: %[[cond:.*]] = icmp ne i32 %[[continue]], 0 + // CHECK: br i1 %[[cond]], label %omp_loop.header{{.*}}, label %omp_loop.exit{{.*}} + llvm.call @body(%iv) : (i32) -> () + omp.yield + } + llvm.return +} + +// ----- + llvm.func @body(i64) llvm.func @test_omp_wsloop_auto(%lb : i64, %ub : i64, %step : i64) -> () {