Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Show First 20 Lines • Show All 1,238 Lines • ▼ Show 20 Lines | LogicalResult SubgroupMmaComputeOp::verify() { | ||||
if (aShape[1] != bShape[0] || aShape[0] != cShape[0] || | if (aShape[1] != bShape[0] || aShape[0] != cShape[0] || | ||||
bShape[1] != cShape[1]) | bShape[1] != cShape[1]) | ||||
return emitError("operand shapes do not satisfy matmul constraints"); | return emitError("operand shapes do not satisfy matmul constraints"); | ||||
return success(); | return success(); | ||||
} | } | ||||
/// This is a common class used for patterns of the form | /// This is a common class used for patterns of the form | ||||
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast | /// "someop(memrefcast) -> someop". It folds the source of a memref.cast | ||||
/// into the root operation directly. | /// into the root operation directly if source and user have the same memref | ||||
/// element type. | |||||
static LogicalResult foldMemRefCast(Operation *op) { | static LogicalResult foldMemRefCast(Operation *op) { | ||||
bool folded = false; | bool folded = false; | ||||
// Returns true iff memref types have the same element type. | |||||
auto sameEltType = [](Value a, Value b) -> bool { | |||||
auto aEltType = a.getType().cast<MemRefType>().getElementType(); | |||||
auto bEltType = b.getType().cast<MemRefType>().getElementType(); | |||||
return aEltType == bEltType; | |||||
}; | |||||
for (OpOperand &operand : op->getOpOperands()) { | for (OpOperand &operand : op->getOpOperands()) { | ||||
auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>(); | auto cast = operand.get().getDefiningOp<mlir::memref::CastOp>(); | ||||
if (cast) { | if (!cast || !sameEltType(operand.get(), cast.getOperand())) | ||||
continue; | |||||
operand.set(cast.getOperand()); | operand.set(cast.getOperand()); | ||||
folded = true; | folded = true; | ||||
} | } | ||||
} | |||||
return success(folded); | return success(folded); | ||||
} | } | ||||
LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands, | LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands, | ||||
SmallVectorImpl<::mlir::OpFoldResult> &results) { | SmallVectorImpl<::mlir::OpFoldResult> &results) { | ||||
return foldMemRefCast(*this); | return foldMemRefCast(*this); | ||||
} | } | ||||
▲ Show 20 Lines • Show All 143 Lines • Show Last 20 Lines |