diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -560,8 +560,10 @@ // Header block to its merge (and continue) target mapping. BlockMergeInfoMap blockMergeInfo; - // Block to its phi (block argument) mapping. - DenseMap blockPhiInfo; + // For each pair of {predecessor, target} blocks, maps the pair of blocks to + // the list of phi arguments passed from predecessor to target. + DenseMap, BlockPhiInfo> + blockPhiInfo; // Result to value mapping. DenseMap valueMap; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1573,7 +1573,8 @@ for (unsigned i = 2, e = operands.size(); i < e; i += 2) { uint32_t value = operands[i]; Block *predecessor = getOrCreateBlock(operands[i + 1]); - blockPhiInfo[predecessor].push_back(value); + std::pair predecessorTargetPair{predecessor, curBlock}; + blockPhiInfo[predecessorTargetPair].push_back(value); LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor << " with arg id = " << value << '\n'); } @@ -1853,7 +1854,8 @@ OpBuilder::InsertionGuard guard(opBuilder); for (const auto &info : blockPhiInfo) { - Block *block = info.first; + Block *block = info.first.first; + Block *target = info.first.second; const BlockPhiInfo &phiInfo = info.second; LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); @@ -1882,6 +1884,24 @@ opBuilder.create(branchOp.getLoc(), branchOp.getTarget(), blockArgs); branchOp.erase(); + } else if (auto branchCondOp = dyn_cast(op)) { + assert((branchCondOp.getTrueBlock() == target || + branchCondOp.getFalseBlock() == target) && + "expected target 
to be either the true or false target"); + if (target == branchCondOp.getTrueBlock()) + opBuilder.create( + branchCondOp.getLoc(), branchCondOp.condition(), blockArgs, + branchCondOp.getFalseBlockArguments(), + branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(), + branchCondOp.getFalseBlock()); + else + opBuilder.create( + branchCondOp.getLoc(), branchCondOp.condition(), + branchCondOp.getTrueBlockArguments(), blockArgs, + branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(), + branchCondOp.getFalseBlock()); + + branchCondOp.erase(); } else { return emitError(unknownLoc, "unimplemented terminator for Phi creation"); } diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -959,7 +959,7 @@ // OpPhi | result type | result | (value , parent block ) pair // So we need to collect all predecessor blocks and the arguments they send // to this block. - SmallVector, 4> predecessors; + SmallVector, 4> predecessors; for (Block *predecessor : block->getPredecessors()) { auto *terminator = predecessor->getTerminator(); // The predecessor here is the immediate one according to MLIR's IR @@ -971,7 +971,23 @@ // structured control flow op's merge block. 
predecessor = getPhiIncomingBlock(predecessor); if (auto branchOp = dyn_cast(terminator)) { - predecessors.emplace_back(predecessor, branchOp.operand_begin()); + predecessors.emplace_back(predecessor, branchOp.getOperands()); + } else if (auto branchCondOp = + dyn_cast(terminator)) { + Optional blockOperands; + + // Use `terminator` (owned by the immediate MLIR predecessor) to find the + // successor index; `predecessor` may have been remapped just above. + for (auto successorIdx : + llvm::seq(0, terminator->getNumSuccessors())) + if (terminator->getSuccessor(successorIdx) == block) { + blockOperands = branchCondOp.getSuccessorOperands(successorIdx); + break; + } + + assert(blockOperands && !blockOperands->empty() && + "expected non-empty block operand range"); + predecessors.emplace_back(predecessor, *blockOperands); } else { return terminator->emitError("unimplemented terminator for Phi creation"); } @@ -996,7 +1012,7 @@ phiArgs.push_back(phiID); for (auto predIndex : llvm::seq(0, predecessors.size())) { - Value value = *(predecessors[predIndex].second + argIndex); + Value value = predecessors[predIndex].second[argIndex]; uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId << ") value " << value << ' '); diff --git a/mlir/test/Target/SPIRV/phi.mlir b/mlir/test/Target/SPIRV/phi.mlir --- a/mlir/test/Target/SPIRV/phi.mlir +++ b/mlir/test/Target/SPIRV/phi.mlir @@ -286,3 +286,60 @@ spv.EntryPoint "GLCompute" @fmul_kernel spv.ExecutionMode @fmul_kernel "LocalSize", 32, 1, 1 } + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { +// CHECK-LABEL: @cond_branch_true_argument + spv.func @cond_branch_true_argument() -> () "None" { + %true = spv.Constant true + %zero = spv.Constant 0 : i32 + %one = spv.Constant 1 : i32 +// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}}, %{{.*}} : i32, i32), ^[[false1:.*]] + spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1 +// CHECK: [[true1]](%{{.*}}: i32, %{{.*}}: i32) + ^true1(%arg0: i32, %arg1: i32): + spv.Return +// CHECK: [[false1]]: + 
^false1: + spv.Return + } +} + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { +// CHECK-LABEL: @cond_branch_false_argument + spv.func @cond_branch_false_argument() -> () "None" { + %true = spv.Constant true + %zero = spv.Constant 0 : i32 + %one = spv.Constant 1 : i32 +// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]], ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32) + spv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32) +// CHECK: [[true1]]: + ^true1: + spv.Return +// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32): + ^false1(%arg0: i32, %arg1: i32): + spv.Return + } +} + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { +// CHECK-LABEL: @cond_branch_true_and_false_argument + spv.func @cond_branch_true_and_false_argument() -> () "None" { + %true = spv.Constant true + %zero = spv.Constant 0 : i32 + %one = spv.Constant 1 : i32 +// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}} : i32), ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32) + spv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32) +// CHECK: [[true1]](%{{.*}}: i32): + ^true1(%arg0: i32): + spv.Return +// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32): + ^false1(%arg1: i32, %arg2: i32): + spv.Return + } +}