diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -536,15 +536,14 @@ return getBody()->getArguments().drop_front(getRank()); } - /// Return the thread indices in the order specified by the - /// given mapping argument. Return failure is - /// mapping is not a valid permutation. - FailureOr> getPermutedThreadIndices(ArrayRef mapping); - - /// Return the number of threads in the order specified by the - /// given mapping argument. - /// Return failure is mapping is not a valid permutation. - FailureOr> getPermutedNumThreads(OpBuilder &b, ArrayRef mapping); + /// Helper to sort `values` according to matching `keys`. + /// Take a custom `compare` binary comparator which returns true if the first + /// element is smaller than the second (i.e. compatible with std::sort). + /// This is a helper typically used to sort numThreads values before they are + /// mapped to concrete physical dimensions of hardware. 
+ static SmallVector getValuesSortedByKey( + ArrayRef keys, ValueRange values, + llvm::function_ref compare); // The ensureTerminator method generated by SingleBlockImplicitTerminator is // unaware of the fact that our terminator also needs a region to be diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Value.h" #include "llvm/ADT/None.h" @@ -157,45 +158,75 @@ SmallVectorImpl &)> blockIdGenerator, SmallVectorImpl &gridDims, TransformOpInterface transformOp) { + // Step 0. Target-specific verifications. There is no good place to anchor + // those right now: the ForeachThreadOp is target-independent and the + // transform op does not apply to individual ForeachThreadOp. + MLIRContext *ctx = foreachThreadOp->getContext(); + Location loc = foreachThreadOp->getLoc(); + Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX); + Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY); + Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ); if (foreachThreadOp.getNumResults() > 0) return transformOp.emitSilenceableError() - << "only bufferized scf.foreach_thread lowers to gpu.block_id"; + << "only bufferized scf.foreach_thread lowers to " + "gpu.block_id"; if (foreachThreadOp.getNumThreads().size() > 3) return transformOp.emitSilenceableError() - << "scf.foreach_thread with rank > 3 does not lower to gpu.block_id"; - - // Step 0. Outline the compute workload region and set up the workload - // operands. 
- SmallVector mapping; + << "scf.foreach_thread with rank > 3 does not lower to " + "gpu.block_id"; + if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) { + return !v.getDefiningOp(); + })) { + return transformOp.emitSilenceableError() + << "unsupported dynamic griddim size"; + } if (!foreachThreadOp.getMapping().has_value()) return transformOp.emitSilenceableError() << "mapping must be present"; - for (DeviceMappingAttrInterface map : - foreachThreadOp.getMapping()->getValue()) { - if (auto blockMap = map.dyn_cast()) { - mapping.push_back((int64_t)blockMap.getBlock()); - } else { - return transformOp.emitSilenceableError() - << "mapping must be #gpu.block"; - } + SmallVector blockMapping = + llvm::to_vector(foreachThreadOp.getMapping()->getValue()); + if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) { + return !map.isa(); + })) { + return transformOp.emitSilenceableError() + << "mapping must be #gpu.block"; } - FailureOr> potentialGridDim = - foreachThreadOp.getPermutedNumThreads(rewriter, mapping); - - if (failed(potentialGridDim) || - llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) { - return !getConstantIntValue(ofr).has_value(); - })) { - return transformOp.emitSilenceableError() << "unsupported dynamic gridDim"; + // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary. + SmallVector numBlocks = + llvm::to_vector(foreachThreadOp.getNumThreads()); + // Ensure we have 3 block sizes, one for each id. + Value one; + for (auto attr : {bX, bY, bZ}) { + if (std::find(blockMapping.begin(), blockMapping.end(), attr) == + blockMapping.end()) { + blockMapping.push_back(attr); + one = one ? one : rewriter.create(loc, 1); + numBlocks.push_back(one); + } } - for (OpFoldResult ofr : *potentialGridDim) - gridDims.push_back(getConstantIntValue(ofr).value()); + // Step 2. sort the values by the corresponding GPUBlockMappingAttr. 
+ auto comparator = [](Attribute a, Attribute b) -> bool { + return static_cast(a.cast().getBlock()) < + static_cast(b.cast().getBlock()); + }; + SmallVector gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey( + blockMapping, numBlocks, comparator); + for (Value v : gridDimValues) + gridDims.push_back(v.getDefiningOp().value()); + // Step 3. Generate the blockIds using the provided generator and map the + // induction variables to the newly created ops. SmallVector blockOps; blockIdGenerator(rewriter, foreachThreadOp, blockOps); + BlockAndValueMapping bvm; + for (auto [blockIdx, blockDim] : + llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) { + bvm.map(blockIdx, blockOps[static_cast( + blockDim.cast().getBlock())]); + } - // Step 1. Move the body of foreachThreadOp. + // Step 4. Move the body of foreachThreadOp. // Erase the terminator first, it will not be used since we are on buffers. rewriter.eraseOp(foreachThreadOp.getTerminator()); Block *targetBlock = foreachThreadOp->getBlock(); @@ -204,20 +235,16 @@ targetBlock->getOperations().splice(insertionPoint, sourceBlock.getOperations()); - // Step 2. RAUW thread indices to thread ops. - SmallVector threadIndices = - *foreachThreadOp.getPermutedThreadIndices(mapping); - assert(blockOps.size() == 3 && "3 block id ops are required"); - for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) { - Value val = blockIdx; - Value blkOp = blockOp; - if (!val) - continue; - for (Operation *user : llvm::make_early_inc_range(val.getUsers())) - user->replaceUsesOfWith(val, blkOp); + // Step 5. RAUW thread indices to thread ops. + for (Value blockIdx : foreachThreadOp.getThreadIndices()) { + for (Operation *user : llvm::make_early_inc_range(blockIdx.getUsers())) { + rewriter.updateRootInPlace(user, [&]() { + user->replaceUsesOfWith(blockIdx, bvm.lookup(blockIdx)); + }); + } } - // Step 3. Erase old op. + // Step 6. Erase old op. 
rewriter.eraseOp(foreachThreadOp); return DiagnosedSilenceableFailure::success(); @@ -252,11 +279,10 @@ OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(foreachOp); IndexType indexType = rewriter.getIndexType(); - SmallVector gpuDims{Dimension::x, Dimension::y, Dimension::z}; - for (int64_t idx : llvm::seq(0, gpuDims.size())) { - blockOps.push_back( - rewriter.create(loc, indexType, gpuDims[idx])); - } + blockOps = SmallVector{ + rewriter.create(loc, indexType, Dimension::x), + rewriter.create(loc, indexType, Dimension::y), + rewriter.create(loc, indexType, Dimension::z)}; } DiagnosedSilenceableFailure @@ -333,6 +359,9 @@ RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp, const SmallVectorImpl &globalBlockDims, bool syncAfterDistribute, llvm::Optional transformOp) { + // Step 0. Target-specific verifications. There is no good place to anchor + // those right now: the ForeachThreadOp is target-independent and the + // transform op does not apply to individual ForeachThreadOp. 
auto failureHelper = [&](const Twine &message) -> DiagnosedSilenceableFailure { if (transformOp.has_value()) { @@ -340,54 +369,79 @@ } return emitDefiniteFailure(foreachThreadOp, message); }; - + MLIRContext *ctx = foreachThreadOp->getContext(); + Location loc = foreachThreadOp->getLoc(); + Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX); + Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY); + Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ); if (foreachThreadOp.getNumResults() > 0) return failureHelper( "only bufferized scf.foreach_thread lowers to gpu.thread_id"); - if (foreachThreadOp.getNumThreads().size() > 3) return failureHelper( "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id"); - - SmallVector mapping; + if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) { + return !v.getDefiningOp(); + })) { + return failureHelper("unsupported dynamic blockdim size"); + } if (!foreachThreadOp.getMapping().has_value()) return failureHelper("mapping must be present"); - for (DeviceMappingAttrInterface map : - foreachThreadOp.getMapping()->getValue()) { - if (auto threadMap = map.dyn_cast()) { - mapping.push_back((int64_t)threadMap.getThread()); - } else { - return failureHelper("mapping must be #gpu.thread"); - } - } - FailureOr> potentialBlockDim = - foreachThreadOp.getPermutedNumThreads(rewriter, mapping); - if (failed(potentialBlockDim) || - llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) { - return !getConstantIntValue(ofr).has_value(); + SmallVector threadMapping = + llvm::to_vector(foreachThreadOp.getMapping()->getValue()); + if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) { + return !map.isa(); })) { - return failureHelper("unsupported dynamic blockdim size"); + // transformOp is optional and may be unset; report via failureHelper. + return failureHelper("mapping must be #gpu.thread"); } - SmallVector blockDim = - llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) { - return 
getConstantIntValue(ofr).value(); + // Step 1. Complete the threadMapping to a full mapping (with 1s) if + // necessary. + SmallVector numThreads = + llvm::to_vector(foreachThreadOp.getNumThreads()); + // Ensure we have 3 block sizes, one for each id. + Value one; + for (auto attr : {tX, tY, tZ}) { + if (std::find(threadMapping.begin(), threadMapping.end(), attr) == + threadMapping.end()) { + threadMapping.push_back(attr); + one = one ? one : rewriter.create(loc, 1); + numThreads.push_back(one); + } + } + + // Step 2. sort the values by the corresponding GPUThreadMappingAttr. + auto comparator = [](Attribute a, Attribute b) -> bool { + return static_cast(a.cast().getThread()) < + static_cast(b.cast().getThread()); + }; + SmallVector blockDimValues = + scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads, + comparator); + SmallVector blockDims = + llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) { + return v.getDefiningOp().value(); })); - // Step 1. Create the gpu.thread ops - Location loc = foreachThreadOp.getLoc(); + // Step 3. Create the gpu.thread ops and map the induction variables to the + // newly created ops. IndexType indexType = rewriter.getIndexType(); - - SmallVector gpuDims{Dimension::x, Dimension::y, Dimension::z}; - SmallVector threadOps; - for (int64_t idx : llvm::seq(0, blockDim.size())) { - threadOps.push_back( - rewriter.create(loc, indexType, gpuDims[idx])); + SmallVector threadOps{ + rewriter.create(loc, indexType, Dimension::x), + rewriter.create(loc, indexType, Dimension::y), + rewriter.create(loc, indexType, Dimension::z)}; + BlockAndValueMapping bvm; + for (auto [blockIdx, blockDim] : + llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) { + bvm.map(blockIdx, threadOps[static_cast( + blockDim.cast().getThread())]); } - // Step 2. Maybe create conditionals to predicate the region. + + // Step 4. Maybe create conditionals to predicate the region. 
Value predicate; for (auto [threadId, blockDim, globalBlockDim] : - llvm::zip(threadOps, blockDim, globalBlockDims)) { + llvm::zip(threadOps, blockDims, globalBlockDims)) { if (blockDim > globalBlockDim) { return failureHelper( "The requested GPU threads are fewer than the number of loop trip " @@ -404,19 +458,19 @@ : tmpPredicate; } - // Step 3. Move the body of foreachThreadOp. + // Step 5. Move the body of foreachThreadOp. // Erase the terminator first, it will not be used. rewriter.eraseOp(foreachThreadOp.getTerminator()); Block *targetBlock; Block::iterator insertionPoint; if (predicate) { - // Step 3.a. If predicated, move at the beginning. + // Step 5.a. If predicated, move at the beginning. auto ifOp = rewriter.create(loc, predicate, /*withElseRegion=*/false); targetBlock = ifOp.thenBlock(); insertionPoint = ifOp.thenBlock()->begin(); } else { - // Step 3.a. Otherwise, move inline just before foreachThreadOp. + // Step 5.b. Otherwise, move inline just before foreachThreadOp. targetBlock = foreachThreadOp->getBlock(); insertionPoint = Block::iterator(foreachThreadOp); } @@ -424,25 +478,21 @@ targetBlock->getOperations().splice(insertionPoint, sourceBlock.getOperations()); - // Step 4. RAUW thread indices to thread ops. - SmallVector threadIndices = - *foreachThreadOp.getPermutedThreadIndices(mapping); - for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) { - Value val = threadIdx; - Value op = threadOp; - if (!val) - continue; - for (Operation *user : llvm::make_early_inc_range(val.getUsers())) { - user->replaceUsesOfWith(val, op); + // Step 6. RAUW thread indices to thread ops. + for (Value threadIdx : foreachThreadOp.getThreadIndices()) { + for (Operation *user : llvm::make_early_inc_range(threadIdx.getUsers())) { + rewriter.updateRootInPlace(user, [&]() { + user->replaceUsesOfWith(threadIdx, bvm.lookup(threadIdx)); + }); } } - // Step 5. syncthreads. + // Step 7. syncthreads. 
// TODO: Need warpsync if (syncAfterDistribute) rewriter.create(loc); - // Step 6. Erase old op. + // Step 8. Erase old op. rewriter.eraseOp(foreachThreadOp); return DiagnosedSilenceableFailure::success(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1114,12 +1114,15 @@ if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType()) return emitOpError("type mismatch between ") << i << "-th output and corresponding block argument"; - if (getMapping().has_value()) + if (getMapping().has_value() && !getMapping()->empty()) { + if (static_cast(getMapping()->size()) != getRank()) + return emitOpError() << "mapping attribute size must match op rank"; for (auto map : getMapping()->getValue()) { if (!isa(map)) return emitOpError() << getMappingAttrName() << " is not device mapping attribute"; } + } return success(); } @@ -1294,59 +1297,21 @@ return cast(getBody()->getTerminator()); } -template -static FailureOr> permute(const SmallVector &vals, - ArrayRef perm) { - if (vals.size() != perm.size()) - return failure(); - SmallVector result(vals.size()); - SmallVector seen(vals.size()); - for (auto [idx, val] : llvm::zip(perm, vals)) { - // Already seen, invalid mapping. - if (seen[idx]) - return failure(); - result[idx] = val; - seen[idx] = true; - } - // Some not seen, invalid mapping. - if (!llvm::all_of(seen, [](bool b) { return b; })) - return failure(); - return result; -} - -/// Helper to get apply the `mapping` permutation of a -/// `foreachThreadOp` to `values`. -template -static FailureOr> -getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp, - const SmallVector &values, - ArrayRef mapping) { - // Apply mapping permutation if specified. 
- FailureOr> maybePermuted = permute(values, mapping); - if (failed(maybePermuted)) - return foreachThreadOp->emitError("invalid permutation"); - return *maybePermuted; - return values; -} - -/// Return the thread indices in the order specified by the mapping -/// attribute. Return failure is mapping is not a valid permutation. -FailureOr> -ForeachThreadOp::getPermutedThreadIndices(ArrayRef mapping) { - SmallVector threadCountValues = this->getThreadIndices(); - threadCountValues.resize(3, Value()); - return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping); -} - -/// Return the number of threads in the order specified by the -/// mapping attribute. -/// Return failure is mapping is not a valid permutation. -FailureOr> -ForeachThreadOp::getPermutedNumThreads(OpBuilder &b, - ArrayRef mapping) { - SmallVector threadCountValues = this->getNumThreads(); - threadCountValues.resize(3, b.getIndexAttr(1)); - return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping); +/// Helper to sort `values` according to matching `keys`. 
+SmallVector ForeachThreadOp::getValuesSortedByKey( + ArrayRef keys, ValueRange values, + llvm::function_ref compare) { + if (keys.empty()) + return values; + assert(keys.size() == values.size() && "unexpected mismatching sizes"); + auto indices = llvm::to_vector(llvm::seq(0, values.size())); + std::sort(indices.begin(), indices.end(), + [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); }); + SmallVector res; + res.reserve(values.size()); + for (int64_t i = 0, e = indices.size(); i < e; ++i) + res.push_back(values[indices[i]]); + return res; } ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) { diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir --- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -67,7 +67,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -79,7 +79,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -106,7 +106,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } return %y : memref<2 x 32 x f32> @@ -131,7 +131,7 @@ scf.foreach_thread (%i, %j, %k, %l) in (%c2, %c32,%c32,%c32) { %4 = memref.load %x[%i, %j, %k, %l] : memref<2x32x32x32xf32> memref.store %4, %y[%i, %j, %k, %l] : memref<2x32x32x32xf32> - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = 
[#gpu.thread, #gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } return %y : memref<2x32x32x32xf32> @@ -197,14 +197,14 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.thread, #gpu.thread] } scf.foreach_thread (%i, %j) in (%c7, %c9) { %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -232,14 +232,14 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.thread, #gpu.thread] } scf.foreach_thread (%i, %j) in (%c7, %c9) { %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.thread, #gpu.thread] } return %y : memref<2 x 32 x f32> } @@ -261,7 +261,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } { mapping = [#gpu.block, #gpu.block, #gpu.block] } + } { mapping = [#gpu.block, #gpu.block] } return %y : memref<2 x 32 x f32> } diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir --- a/mlir/test/Dialect/GPU/transform-gpu.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu.mlir @@ -24,7 +24,7 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } { mapping = [#gpu.block, #gpu.block, 
#gpu.block]} + } { mapping = [#gpu.block, #gpu.block]} gpu.terminator } return %y : !type @@ -73,12 +73,12 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread]} + } { mapping = [#gpu.thread, #gpu.thread]} scf.foreach_thread (%i) in (%c12) { %7 = memref.load %t[%i] : !type1d %8 = arith.addf %alpha, %7 : f32 memref.store %8, %t[%i] : !type1d - } {mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } {mapping = [#gpu.thread] } gpu.terminator } return %y : !type @@ -118,8 +118,8 @@ %5 = memref.load %y[%i, %j, %k, %l] : !type4d %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j, %k, %l] : !type4d - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } - } { mapping = [#gpu.block, #gpu.block, #gpu.block] } + } { mapping = [#gpu.thread, #gpu.thread] } + } { mapping = [#gpu.block, #gpu.block] } return %y : !type4d } @@ -151,7 +151,7 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } return %y : !type diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -575,6 +575,21 @@ // ----- +func.func @mismatched_mapping(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { + %one = arith.constant 1 : index + %c65535 = arith.constant 65535 : index + // expected-error @below {{'scf.foreach_thread' op mapping attribute size must match op rank}} + scf.foreach_thread (%i, %j) in (%c65535, %c65535) { + %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> + %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> + %6 = math.fma %alpha, %4, %5 : f32 + memref.store %6, %y[%i, %j] : memref<2 x 32 x 
f32> + } { mapping = [#gpu.block, #gpu.block, #gpu.block] } + return %y : memref<2 x 32 x f32> +} + +// ----- + func.func @switch_wrong_case_count(%arg0: index) { // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}} "scf.index_switch"(%arg0) ({